From f6e7a4e68ee458d9f655b82bb2e3ff0c2e04398b Mon Sep 17 00:00:00 2001 From: clarkzinzow Date: Fri, 1 Dec 2023 16:54:40 -0800 Subject: [PATCH 1/4] Add broadcast join. --- daft/context.py | 8 +- daft/daft.pyi | 5 +- daft/dataframe/dataframe.py | 30 +++- daft/execution/execution_step.py | 11 +- daft/execution/physical_plan.py | 128 +++++++++++++++++- daft/execution/rust_physical_plan_shim.py | 24 +++- daft/io/file_path.py | 3 + daft/logical/builder.py | 6 +- daft/runners/partitioning.py | 7 + daft/runners/pyrunner.py | 10 +- daft/runners/ray_runner.py | 55 +++++--- src/common/daft-config/src/lib.rs | 8 +- src/common/daft-config/src/python.rs | 15 +- src/daft-micropartition/src/ops/join.rs | 1 + src/daft-plan/src/builder.rs | 4 + .../src/physical_ops/broadcast_join.rs | 37 +++++ .../physical_ops/{join.rs => hash_join.rs} | 4 +- src/daft-plan/src/physical_ops/mod.rs | 6 +- src/daft-plan/src/physical_plan.rs | 65 ++++++++- src/daft-plan/src/planner.rs | 35 ++++- src/daft-plan/src/source_info/mod.rs | 3 + src/daft-table/src/ops/joins/hash_join.rs | 1 + tests/dataframe/test_joins.py | 24 +++- 23 files changed, 434 insertions(+), 56 deletions(-) create mode 100644 src/daft-plan/src/physical_ops/broadcast_join.rs rename src/daft-plan/src/physical_ops/{join.rs => hash_join.rs} (94%) diff --git a/daft/context.py b/daft/context.py index 9954b29260..d7eb957b08 100644 --- a/daft/context.py +++ b/daft/context.py @@ -195,6 +195,7 @@ def set_config( config: PyDaftConfig | None = None, merge_scan_tasks_min_size_bytes: int | None = None, merge_scan_tasks_max_size_bytes: int | None = None, + broadcast_join_size_bytes_threshold: int | None = None, ) -> DaftContext: """Globally sets various configuration parameters which control various aspects of Daft execution @@ -203,10 +204,12 @@ def set_config( that the old (current) config should be used. merge_scan_tasks_min_size_bytes: Minimum size in bytes when merging ScanTasks when reading files from storage. Increasing this value will make Daft perform more merging of files into a single partition before yielding, - which leads to bigger but fewer partitions. (Defaults to 64MB) + which leads to bigger but fewer partitions. (Defaults to 64 MiB) merge_scan_tasks_max_size_bytes: Maximum size in bytes when merging ScanTasks when reading files from storage. Increasing this value will increase the upper bound of the size of merged ScanTasks, which leads to bigger but - fewer partitions. (Defaults to 512MB) + fewer partitions. (Defaults to 512 MiB) + broadcast_join_size_bytes_threshold: If one side of a join is smaller than this threshold, a broadcast join will be used. + Default is 10 MiB. """ ctx = get_context() if ctx.disallow_set_runner: @@ -220,6 +223,7 @@ def set_config( new_daft_config = old_daft_config.with_config_values( merge_scan_tasks_min_size_bytes=merge_scan_tasks_min_size_bytes, merge_scan_tasks_max_size_bytes=merge_scan_tasks_max_size_bytes, + broadcast_join_size_bytes_threshold=broadcast_join_size_bytes_threshold, ) ctx.daft_config = new_daft_config diff --git a/daft/daft.pyi b/daft/daft.pyi index 47240b6019..7fe6f6a4b4 100644 --- a/daft/daft.pyi +++ b/daft/daft.pyi @@ -1058,7 +1058,7 @@ class LogicalPlanBuilder: @staticmethod def in_memory_scan( - partition_key: str, cache_entry: PartitionCacheEntry, schema: PySchema, num_partitions: int + partition_key: str, cache_entry: PartitionCacheEntry, schema: PySchema, num_partitions: int, size_bytes: int ) -> LogicalPlanBuilder: ... @staticmethod def table_scan_with_scan_operator(scan_operator: ScanOperatorHandle) -> LogicalPlanBuilder: ... @@ -1102,11 +1102,14 @@ class PyDaftConfig: self, merge_scan_tasks_min_size_bytes: int | None = None, merge_scan_tasks_max_size_bytes: int | None = None, + broadcast_join_size_bytes_threshold: int | None = None, ) -> PyDaftConfig: ... @property def merge_scan_tasks_min_size_bytes(self): ... @property def merge_scan_tasks_max_size_bytes(self): ... + @property + def broadcast_join_size_bytes_threshold(self): ... def build_type() -> str: ... def version() -> str: ... diff --git a/daft/dataframe/dataframe.py b/daft/dataframe/dataframe.py index ebcd8ddcc2..c4b729e039 100644 --- a/daft/dataframe/dataframe.py +++ b/daft/dataframe/dataframe.py @@ -85,10 +85,16 @@ def _builder(self) -> LogicalPlanBuilder: return self.__builder else: num_partitions = self._result_cache.num_partitions() + size_bytes = self._result_cache.size_bytes() # Partition set should always be set on cache entry. - assert num_partitions is not None, "Partition set should always be set on cache entry" + assert ( + num_partitions is not None and size_bytes is not None + ), "Partition set should always be set on cache entry" return self.__builder.from_in_memory_scan( - self._result_cache, self.__builder.schema(), num_partitions=num_partitions + self._result_cache, + self.__builder.schema(), + num_partitions=num_partitions, + size_bytes=size_bytes, ) def _get_current_builder(self) -> LogicalPlanBuilder: @@ -273,7 +279,11 @@ def _from_tables(cls, *parts: MicroPartition) -> "DataFrame": context = get_context() cache_entry = context.runner().put_partition_set_into_cache(result_pset) - builder = LogicalPlanBuilder.from_in_memory_scan(cache_entry, parts[0].schema(), result_pset.num_partitions()) + size_bytes = result_pset.size_bytes() + assert size_bytes is not None, "In-memory data should always have non-None size in bytes" + builder = LogicalPlanBuilder.from_in_memory_scan( + cache_entry, parts[0].schema(), result_pset.num_partitions(), size_bytes + ) return cls(builder) ### @@ -1233,8 +1243,13 @@ def _from_ray_dataset(cls, ds: "RayDataset") -> "DataFrame": partition_set, schema = ray_runner_io.partition_set_from_ray_dataset(ds) cache_entry = context.runner().put_partition_set_into_cache(partition_set) + size_bytes = partition_set.size_bytes() + assert size_bytes is not None, "In-memory data should always have non-None size in bytes" builder = LogicalPlanBuilder.from_in_memory_scan( - cache_entry, schema=schema, num_partitions=partition_set.num_partitions() + cache_entry, + schema=schema, + num_partitions=partition_set.num_partitions(), + size_bytes=size_bytes, ) return cls(builder) @@ -1300,8 +1315,13 @@ def _from_dask_dataframe(cls, ddf: "dask.DataFrame") -> "DataFrame": partition_set, schema = ray_runner_io.partition_set_from_dask_dataframe(ddf) cache_entry = context.runner().put_partition_set_into_cache(partition_set) + size_bytes = partition_set.size_bytes() + assert size_bytes is not None, "In-memory data should always have non-None size in bytes" builder = LogicalPlanBuilder.from_in_memory_scan( - cache_entry, schema=schema, num_partitions=partition_set.num_partitions() + cache_entry, + schema=schema, + num_partitions=partition_set.num_partitions(), + size_bytes=size_bytes, ) return cls(builder) diff --git a/daft/execution/execution_step.py b/daft/execution/execution_step.py index 672aa0def7..f8eabe3e9e 100644 --- a/daft/execution/execution_step.py +++ b/daft/execution/execution_step.py @@ -629,12 +629,21 @@ class Join(SingleOutputInstruction): left_on: ExpressionsProjection right_on: ExpressionsProjection how: JoinType + is_swapped: bool def run(self, inputs: list[MicroPartition]) -> list[MicroPartition]: return self._join(inputs) def _join(self, inputs: list[MicroPartition]) -> list[MicroPartition]: - [left, right] = inputs + *lefts, right = inputs + if len(lefts) > 1: + left = MicroPartition.concat(lefts) + else: + left = lefts[0] + if self.is_swapped: + # Swap left/right back. + # We don't need to swap left_on and right_on since those were never swapped in the first place. + left, right = right, left result = left.join( right, left_on=self.left_on, diff --git a/daft/execution/physical_plan.py b/daft/execution/physical_plan.py index a25f3511b4..e5862cee74 100644 --- a/daft/execution/physical_plan.py +++ b/daft/execution/physical_plan.py @@ -186,7 +186,7 @@ def pipeline_instruction( ) -def join( +def hash_join( left_plan: InProgressPhysicalPlan[PartitionT], right_plan: InProgressPhysicalPlan[PartitionT], left_on: ExpressionsProjection, @@ -233,6 +233,7 @@ def join( left_on=left_on, right_on=right_on, how=how, + is_swapped=False, ) ) yield join_step @@ -271,6 +272,131 @@ def join( return +def broadcast_join( + broadcaster_plan: InProgressPhysicalPlan[PartitionT], + reciever_plan: InProgressPhysicalPlan[PartitionT], + left_on: ExpressionsProjection, + right_on: ExpressionsProjection, + how: JoinType, + is_swapped: bool, +) -> InProgressPhysicalPlan[PartitionT]: + """Broadcast join all partitions from the broadcaster child plan to each partition in the receiver child plan.""" + + def _emit_join_step( + broadcaster_parts: deque[SingleOutputPartitionTask[PartitionT]], + receiver_part: SingleOutputPartitionTask[PartitionT], + ) -> PartitionTaskBuilder[PartitionT]: + # Calculate memory request for task. + broadcaster_size_bytes_ = 0 + broadcaster_partitions = [] + broadcaster_partition_metadatas = [] + null_count = 0 + for next_broadcaster in broadcaster_parts: + next_broadcaster_partition_metadata = next_broadcaster.partition_metadata() + if next_broadcaster_partition_metadata is None or next_broadcaster_partition_metadata.size_bytes is None: + null_count += 1 + else: + broadcaster_size_bytes_ += next_broadcaster_partition_metadata.size_bytes + broadcaster_partitions.append(next_broadcaster.partition()) + broadcaster_partition_metadatas.append(next_broadcaster_partition_metadata) + if null_count == len(broadcaster_parts): + broadcaster_size_bytes = None + elif null_count > 0: + # Impute null size estimates with mean of non-null estimates. + broadcaster_size_bytes = broadcaster_size_bytes_ + math.ceil( + null_count * broadcaster_size_bytes_ / (len(broadcaster_parts) - null_count) + ) + else: + broadcaster_size_bytes = broadcaster_size_bytes_ + receiver_size_bytes = receiver_part.partition_metadata().size_bytes + if broadcaster_size_bytes is None and receiver_size_bytes is None: + size_bytes = None + elif broadcaster_size_bytes is None and receiver_size_bytes is not None: + # Use 1.25x the receiver side as the memory request, assuming that receiver side is ~4x larger than the broadcaster side. + size_bytes = int(1.25 * receiver_size_bytes) + elif receiver_size_bytes is None and broadcaster_size_bytes is not None: + # Use 4x the broadcaster side as the memory request, assuming that receiver side is ~4x larger than the broadcaster side. + size_bytes = 4 * broadcaster_size_bytes + elif broadcaster_size_bytes is not None and receiver_size_bytes is not None: + size_bytes = broadcaster_size_bytes + receiver_size_bytes + + return PartitionTaskBuilder[PartitionT]( + inputs=broadcaster_partitions + [receiver_part.partition()], + partial_metadatas=list(broadcaster_partition_metadatas + [receiver_part.partition_metadata()]), + resource_request=ResourceRequest(memory_bytes=size_bytes), + ).add_instruction( + instruction=execution_step.Join( + left_on=left_on, + right_on=right_on, + how=how, + is_swapped=is_swapped, + ) + ) + + # Materialize the steps from the broadcaster and receiver sources to get partitions. + # As the receiver-side materializations complete, emit new steps to join each broadcaster and receiver partition. + stage_id = next(stage_id_counter) + broadcaster_requests: deque[SingleOutputPartitionTask[PartitionT]] = deque() + broadcaster_parts: deque[SingleOutputPartitionTask[PartitionT]] = deque() + + # First, fully materialize the broadcasting side (broadcaster side) of the join. + while True: + # Moved completed partition tasks in the broadcaster side of the join to the materialized partition set. + while broadcaster_requests and broadcaster_requests[0].done(): + broadcaster_parts.append(broadcaster_requests.popleft()) + + # Execute single child step to pull in more broadcaster-side partitions. + try: + step = next(broadcaster_plan) + if isinstance(step, PartitionTaskBuilder): + step = step.finalize_partition_task_single_output(stage_id=stage_id) + broadcaster_requests.append(step) + yield step + except StopIteration: + if broadcaster_requests: + logger.debug( + "broadcast join blocked on completion of broadcasting side of join.\n broadcaster sources: %s", + broadcaster_requests, + ) + yield None + else: + break + + # Second, broadcast materialized partitions to receiver side of join, as it materializes. + receiver_requests: deque[SingleOutputPartitionTask[PartitionT]] = deque() + receiver_parts: deque[SingleOutputPartitionTask[PartitionT]] = deque() + + while True: + # Moved completed partition tasks in the receiver side of the join to the materialized partition set. + while receiver_requests and receiver_requests[0].done(): + receiver_parts.append(receiver_requests.popleft()) + + # Emit join steps for newly materialized partitions. + # Broadcast all broadcaster partitions to each new receiver partition that was materialized on this dispatch loop. + for receiver_part in receiver_parts: + yield _emit_join_step(broadcaster_parts, receiver_part) + + # We always emit a join step for every newly materialized receiver-side partition, so clear the materialized partition queue. + receiver_parts.clear() + + # Execute single child step to pull in more input partitions. + try: + step = next(reciever_plan) + if isinstance(step, PartitionTaskBuilder): + step = step.finalize_partition_task_single_output(stage_id=stage_id) + receiver_requests.append(step) + yield step + except StopIteration: + if receiver_requests: + logger.debug( + "broadcast join blocked on completion of receiver side of join.\n receiver sources: %s", + receiver_requests, + ) + yield None + else: + return + + def concat( top_plan: InProgressPhysicalPlan[PartitionT], bottom_plan: InProgressPhysicalPlan[PartitionT] ) -> InProgressPhysicalPlan[PartitionT]: diff --git a/daft/execution/rust_physical_plan_shim.py b/daft/execution/rust_physical_plan_shim.py index 4e3644523d..d2193dd8bd 100644 --- a/daft/execution/rust_physical_plan_shim.py +++ b/daft/execution/rust_physical_plan_shim.py @@ -188,7 +188,7 @@ def reduce_merge( return physical_plan.reduce(input, reduce_instruction) -def join( +def hash_join( input: physical_plan.InProgressPhysicalPlan[PartitionT], right: physical_plan.InProgressPhysicalPlan[PartitionT], left_on: list[PyExpr], @@ -197,7 +197,7 @@ def join( ) -> physical_plan.InProgressPhysicalPlan[PartitionT]: left_on_expr_proj = ExpressionsProjection([Expression._from_pyexpr(expr) for expr in left_on]) right_on_expr_proj = ExpressionsProjection([Expression._from_pyexpr(expr) for expr in right_on]) - return physical_plan.join( + return physical_plan.hash_join( left_plan=input, right_plan=right, left_on=left_on_expr_proj, @@ -206,6 +206,26 @@ def join( ) +def broadcast_join( + input: physical_plan.InProgressPhysicalPlan[PartitionT], + right: physical_plan.InProgressPhysicalPlan[PartitionT], + left_on: list[PyExpr], + right_on: list[PyExpr], + join_type: JoinType, + is_swapped: bool, +) -> physical_plan.InProgressPhysicalPlan[PartitionT]: + left_on_expr_proj = ExpressionsProjection([Expression._from_pyexpr(expr) for expr in left_on]) + right_on_expr_proj = ExpressionsProjection([Expression._from_pyexpr(expr) for expr in right_on]) + return physical_plan.broadcast_join( + broadcaster_plan=input, + reciever_plan=right, + left_on=left_on_expr_proj, + right_on=right_on_expr_proj, + how=join_type, + is_swapped=is_swapped, + ) + + def write_file( input: physical_plan.InProgressPhysicalPlan[PartitionT], file_format: FileFormat, diff --git a/daft/io/file_path.py b/daft/io/file_path.py index 079fa9ef5a..7fb1905da5 100644 --- a/daft/io/file_path.py +++ b/daft/io/file_path.py @@ -48,9 +48,12 @@ def from_glob_path(path: str, io_config: Optional[IOConfig] = None) -> DataFrame file_infos_table = MicroPartition._from_pytable(file_infos.to_table()) partition = LocalPartitionSet({0: file_infos_table}) cache_entry = context.runner().put_partition_set_into_cache(partition) + size_bytes = partition.size_bytes() + assert size_bytes is not None, "In-memory data should always have non-None size in bytes" builder = LogicalPlanBuilder.from_in_memory_scan( cache_entry, schema=file_infos_table.schema(), num_partitions=partition.num_partitions(), + size_bytes=size_bytes, ) return DataFrame(builder) diff --git a/daft/logical/builder.py b/daft/logical/builder.py index f948f666a6..5f9f59dc08 100644 --- a/daft/logical/builder.py +++ b/daft/logical/builder.py @@ -74,9 +74,11 @@ def optimize(self) -> LogicalPlanBuilder: @classmethod def from_in_memory_scan( - cls, partition: PartitionCacheEntry, schema: Schema, num_partitions: int + cls, partition: PartitionCacheEntry, schema: Schema, num_partitions: int, size_bytes: int ) -> LogicalPlanBuilder: - builder = _LogicalPlanBuilder.in_memory_scan(partition.key, partition, schema._schema, num_partitions) + builder = _LogicalPlanBuilder.in_memory_scan( + partition.key, partition, schema._schema, num_partitions, size_bytes + ) return cls(builder) @classmethod diff --git a/daft/runners/partitioning.py b/daft/runners/partitioning.py index ec6c04085c..45e86c0d54 100644 --- a/daft/runners/partitioning.py +++ b/daft/runners/partitioning.py @@ -175,6 +175,10 @@ def has_partition(self, idx: PartID) -> bool: def __len__(self) -> int: raise NotImplementedError() + @abstractmethod + def size_bytes(self) -> int | None: + raise NotImplementedError() + @abstractmethod def num_partitions(self) -> int: raise NotImplementedError() @@ -208,6 +212,9 @@ def __setstate__(self, key): def num_partitions(self) -> int | None: return self.value.num_partitions() if self.value is not None else None + def size_bytes(self) -> int | None: + return self.value.size_bytes() if self.value is not None else None + class PartitionSetCache: def __init__(self) -> None: diff --git a/daft/runners/pyrunner.py b/daft/runners/pyrunner.py index 55076c2156..34c9d5e599 100644 --- a/daft/runners/pyrunner.py +++ b/daft/runners/pyrunner.py @@ -64,7 +64,15 @@ def has_partition(self, idx: PartID) -> bool: return idx in self._partitions def __len__(self) -> int: - return sum([len(self._partitions[pid]) for pid in self._partitions]) + return sum(len(partition) for partition in self._partitions.values()) + + def size_bytes(self) -> int | None: + size_bytes_ = [partition.size_bytes() for partition in self._partitions.values()] + size_bytes: list[int] = [size for size in size_bytes_ if size is not None] + if len(size_bytes) != len(size_bytes_): + return None + else: + return sum(size_bytes) def num_partitions(self) -> int: return len(self._partitions) diff --git a/daft/runners/ray_runner.py b/daft/runners/ray_runner.py index b845296c7a..06fde27c76 100644 --- a/daft/runners/ray_runner.py +++ b/daft/runners/ray_runner.py @@ -212,7 +212,15 @@ def has_partition(self, idx: PartID) -> bool: return idx in self._results def __len__(self) -> int: - return sum([self._results[pid].metadata().num_rows for pid in self._results]) + return sum(result.metadata().num_rows for result in self._results.values()) + + def size_bytes(self) -> int | None: + size_bytes_ = [result.metadata().size_bytes for result in self._results.values()] + size_bytes: list[int] = [size for size in size_bytes_ if size is not None] + if len(size_bytes) != len(size_bytes_): + return None + else: + return sum(size_bytes) def num_partitions(self) -> int: return len(self._results) @@ -302,7 +310,7 @@ def partition_set_from_ray_dataset( RayPartitionSet( _daft_config_objref=self.daft_config_objref, _results={ - i: RayMaterializedResult(obj, _daft_config_objref=self.daft_config_objref) + i: RayMaterializedResult(obj, daft_config_objref=self.daft_config_objref) for i, obj in enumerate(daft_vpartitions) }, ), @@ -335,7 +343,7 @@ def partition_set_from_dask_dataframe( RayPartitionSet( _daft_config_objref=self.daft_config_objref, _results={ - i: RayMaterializedResult(obj, _daft_config_objref=self.daft_config_objref) + i: RayMaterializedResult(obj, daft_config_objref=self.daft_config_objref) for i, obj in enumerate(daft_vpartitions) }, ), @@ -412,9 +420,9 @@ def reduce_and_fanout( @ray.remote -def get_meta(daft_config: PyDaftConfig, partition: MicroPartition) -> PartitionMetadata: +def get_metas(daft_config: PyDaftConfig, *partitions: MicroPartition) -> list[PartitionMetadata]: set_config(daft_config) - return PartitionMetadata.from_table(partition) + return [PartitionMetadata.from_table(partition) for partition in partitions] def _ray_num_cpus_provider(ttl_seconds: int = 1) -> Generator[int, None, None]: @@ -578,7 +586,7 @@ def place_in_queue(item): assert isinstance(next_step, SingleOutputPartitionTask) next_step.set_result( [ - RayMaterializedResult(partition, _daft_config_objref=self.daft_config_objref) + RayMaterializedResult(partition, daft_config_objref=self.daft_config_objref) for partition in next_step.inputs ] ) @@ -617,7 +625,6 @@ def place_in_queue(item): dispatch = datetime.now() completed_task_ids = [] for wait_for in ("next_one", "next_batch"): - if not is_active(): break @@ -704,10 +711,9 @@ def _build_partitions(daft_config_objref: ray.ObjectRef, task: PartitionTask[ray task.set_result( [ RayMaterializedResult( - _partition=partition, - _daft_config_objref=daft_config_objref, - _metadatas=metadatas_accessor, - _metadata_index=i, + partition=partition, + metadatas=metadatas_accessor, + metadata_idx=i, ) for i, partition in enumerate(partitions) ] @@ -824,7 +830,7 @@ def put_partition_set_into_cache(self, pset: PartitionSet) -> PartitionCacheEntr pset = RayPartitionSet( _daft_config_objref=self.daft_config_objref, _results={ - pid: RayMaterializedResult(ray.put(val), _daft_config_objref=self.daft_config_objref) + pid: RayMaterializedResult(ray.put(val), daft_config_objref=self.daft_config_objref) for pid, val in pset._partitions.items() }, ) @@ -835,12 +841,22 @@ def runner_io(self) -> RayRunnerIO: return RayRunnerIO(daft_config_objref=self.daft_config_objref) -@dataclass(frozen=True) class RayMaterializedResult(MaterializedResult[ray.ObjectRef]): - _partition: ray.ObjectRef - _daft_config_objref: ray.ObjectRef - _metadatas: PartitionMetadataAccessor | None = None - _metadata_index: int | None = None + def __init__( + self, + partition: ray.ObjectRef, + daft_config_objref: ray.ObjectRef | None = None, + metadatas: PartitionMetadataAccessor | None = None, + metadata_idx: int | None = None, + ): + self._partition = partition + if metadatas is None: + assert metadata_idx is None + assert daft_config_objref is not None + metadatas = PartitionMetadataAccessor(get_metas.remote(daft_config_objref, self._partition)) + metadata_idx = 0 + self._metadatas = metadatas + self._metadata_idx = metadata_idx def partition(self) -> ray.ObjectRef: return self._partition @@ -849,10 +865,7 @@ def vpartition(self) -> MicroPartition: return ray.get(self._partition) def metadata(self) -> PartitionMetadata: - if self._metadatas is not None and self._metadata_index is not None: - return self._metadatas.get_index(self._metadata_index) - else: - return ray.get(get_meta.remote(self._daft_config_objref, self._partition)) + return self._metadatas.get_index(self._metadata_idx) def cancel(self) -> None: return ray.cancel(self._partition) diff --git a/src/common/daft-config/src/lib.rs b/src/common/daft-config/src/lib.rs index b71280ee36..0a2a0b1429 100644 --- a/src/common/daft-config/src/lib.rs +++ b/src/common/daft-config/src/lib.rs @@ -4,13 +4,17 @@ use serde::{Deserialize, Serialize}; pub struct DaftConfig { pub merge_scan_tasks_min_size_bytes: usize, pub merge_scan_tasks_max_size_bytes: usize, + pub broadcast_join_size_bytes_threshold: usize, } impl Default for DaftConfig { fn default() -> Self { DaftConfig { - merge_scan_tasks_min_size_bytes: 64 * 1024 * 1024, // 64MB - merge_scan_tasks_max_size_bytes: 512 * 1024 * 1024, // 512MB + merge_scan_tasks_min_size_bytes: 64 * 1024 * 1024, // 64 MiB + merge_scan_tasks_max_size_bytes: 512 * 1024 * 1024, // 512 MiB + broadcast_join_size_bytes_threshold: std::env::var("DAFT_BROADCAST_JOIN_THRESHOLD") + .map(|s| s.parse::().unwrap()) + .unwrap_or(10 * 1024 * 1024), // 10 MiB } } } diff --git a/src/common/daft-config/src/python.rs b/src/common/daft-config/src/python.rs index f327e5a55d..353858152c 100644 --- a/src/common/daft-config/src/python.rs +++ b/src/common/daft-config/src/python.rs @@ -19,9 +19,10 @@ impl PyDaftConfig { } fn with_config_values( - &mut self, + &self, merge_scan_tasks_min_size_bytes: Option, merge_scan_tasks_max_size_bytes: Option, + broadcast_join_size_bytes_threshold: Option, ) -> PyResult { let mut config = self.config.as_ref().clone(); @@ -31,22 +32,30 @@ impl PyDaftConfig { if let Some(merge_scan_tasks_min_size_bytes) = merge_scan_tasks_min_size_bytes { config.merge_scan_tasks_min_size_bytes = merge_scan_tasks_min_size_bytes; } + if let Some(broadcast_join_size_bytes_threshold) = broadcast_join_size_bytes_threshold { + config.broadcast_join_size_bytes_threshold = broadcast_join_size_bytes_threshold; + } Ok(PyDaftConfig { config: Arc::new(config), }) } - #[getter(merge_scan_tasks_min_size_bytes)] + #[getter] fn get_merge_scan_tasks_min_size_bytes(&self) -> PyResult { Ok(self.config.merge_scan_tasks_min_size_bytes) } - #[getter(merge_scan_tasks_max_size_bytes)] + #[getter] fn get_merge_scan_tasks_max_size_bytes(&self) -> PyResult { Ok(self.config.merge_scan_tasks_max_size_bytes) } + #[getter] + fn get_broadcast_join_size_bytes_threshold(&self) -> PyResult { + Ok(self.config.broadcast_join_size_bytes_threshold) + } + fn __reduce__(&self, py: Python) -> PyResult<(PyObject, (Vec,))> { let bin_data = bincode::serialize(self.config.as_ref()) .expect("DaftConfig should be serializable to bytes"); diff --git a/src/daft-micropartition/src/ops/join.rs b/src/daft-micropartition/src/ops/join.rs index f8d5efbe5f..86338a2380 100644 --- a/src/daft-micropartition/src/ops/join.rs +++ b/src/daft-micropartition/src/ops/join.rs @@ -41,6 +41,7 @@ impl MicroPartition { return Ok(Self::empty(Some(join_schema.into()))); } + // TODO(Clark): Elide concatenations where possible by doing a chunk-aware local table join. let lt = self.concat_or_get(io_stats.clone())?; let rt = right.concat_or_get(io_stats)?; diff --git a/src/daft-plan/src/builder.rs b/src/daft-plan/src/builder.rs index 9e109cde00..874871c4fc 100644 --- a/src/daft-plan/src/builder.rs +++ b/src/daft-plan/src/builder.rs @@ -57,12 +57,14 @@ impl LogicalPlanBuilder { cache_entry: PyObject, schema: Arc, num_partitions: usize, + size_bytes: usize, ) -> DaftResult { let source_info = SourceInfo::InMemoryInfo(InMemoryInfo::new( schema.clone(), partition_key.into(), cache_entry, num_partitions, + size_bytes, )); let logical_plan: LogicalPlan = logical_ops::Source::new(schema.clone(), source_info.into()).into(); @@ -285,12 +287,14 @@ impl PyLogicalPlanBuilder { cache_entry: &PyAny, schema: PySchema, num_partitions: usize, + size_bytes: usize, ) -> PyResult { Ok(LogicalPlanBuilder::in_memory_scan( partition_key, cache_entry.to_object(cache_entry.py()), schema.into(), num_partitions, + size_bytes, )? .into()) } diff --git a/src/daft-plan/src/physical_ops/broadcast_join.rs b/src/daft-plan/src/physical_ops/broadcast_join.rs new file mode 100644 index 0000000000..06ab9fd5d0 --- /dev/null +++ b/src/daft-plan/src/physical_ops/broadcast_join.rs @@ -0,0 +1,37 @@ +use std::sync::Arc; + +use daft_dsl::Expr; + +use crate::{physical_plan::PhysicalPlan, JoinType}; +use serde::{Deserialize, Serialize}; + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct BroadcastJoin { + // Upstream node. + pub left: Arc, + pub right: Arc, + pub left_on: Vec, + pub right_on: Vec, + pub join_type: JoinType, + pub is_swapped: bool, +} + +impl BroadcastJoin { + pub(crate) fn new( + left: Arc, + right: Arc, + left_on: Vec, + right_on: Vec, + join_type: JoinType, + is_swapped: bool, + ) -> Self { + Self { + left, + right, + left_on, + right_on, + join_type, + is_swapped, + } + } +} diff --git a/src/daft-plan/src/physical_ops/join.rs b/src/daft-plan/src/physical_ops/hash_join.rs similarity index 94% rename from src/daft-plan/src/physical_ops/join.rs rename to src/daft-plan/src/physical_ops/hash_join.rs index 79db5097aa..adfd546e60 100644 --- a/src/daft-plan/src/physical_ops/join.rs +++ b/src/daft-plan/src/physical_ops/hash_join.rs @@ -6,7 +6,7 @@ use crate::{physical_plan::PhysicalPlan, JoinType}; use serde::{Deserialize, Serialize}; #[derive(Clone, Debug, Serialize, Deserialize)] -pub struct Join { +pub struct HashJoin { // Upstream node. pub left: Arc, pub right: Arc, @@ -15,7 +15,7 @@ pub struct Join { pub join_type: JoinType, } -impl Join { +impl HashJoin { pub(crate) fn new( left: Arc, right: Arc, diff --git a/src/daft-plan/src/physical_ops/mod.rs b/src/daft-plan/src/physical_ops/mod.rs index 09dc707511..e55aca8dd0 100644 --- a/src/daft-plan/src/physical_ops/mod.rs +++ b/src/daft-plan/src/physical_ops/mod.rs @@ -1,4 +1,5 @@ mod agg; +mod broadcast_join; mod coalesce; mod concat; mod csv; @@ -6,9 +7,9 @@ mod explode; mod fanout; mod filter; mod flatten; +mod hash_join; #[cfg(feature = "python")] mod in_memory; -mod join; mod json; mod limit; mod parquet; @@ -19,6 +20,7 @@ mod sort; mod split; pub use agg::Aggregate; +pub use broadcast_join::BroadcastJoin; pub use coalesce::Coalesce; pub use concat::Concat; pub use csv::{TabularScanCsv, TabularWriteCsv}; @@ -26,9 +28,9 @@ pub use explode::Explode; pub use fanout::{FanoutByHash, FanoutByRange, FanoutRandom}; pub use filter::Filter; pub use flatten::Flatten; +pub use hash_join::HashJoin; #[cfg(feature = "python")] pub use in_memory::InMemoryScan; -pub use join::Join; pub use json::{TabularScanJson, TabularWriteJson}; pub use limit::Limit; pub use parquet::{TabularScanParquet, TabularWriteParquet}; diff --git a/src/daft-plan/src/physical_plan.rs b/src/daft-plan/src/physical_plan.rs index 06f0998286..11d2678600 100644 --- a/src/daft-plan/src/physical_plan.rs +++ b/src/daft-plan/src/physical_plan.rs @@ -52,7 +52,8 @@ pub enum PhysicalPlan { ReduceMerge(ReduceMerge), Aggregate(Aggregate), Concat(Concat), - Join(Join), + HashJoin(HashJoin), + BroadcastJoin(BroadcastJoin), TabularWriteParquet(TabularWriteParquet), TabularWriteJson(TabularWriteJson), TabularWriteCsv(TabularWriteCsv), @@ -135,7 +136,7 @@ impl PhysicalPlan { None, ) .into(), - Self::Join(Join { + Self::HashJoin(HashJoin { left, right, left_on, @@ -159,11 +160,36 @@ impl PhysicalPlan { .into(), } } + Self::BroadcastJoin(BroadcastJoin { right, .. }) => right.partition_spec(), Self::TabularWriteParquet(TabularWriteParquet { input, .. }) => input.partition_spec(), Self::TabularWriteCsv(TabularWriteCsv { input, .. }) => input.partition_spec(), Self::TabularWriteJson(TabularWriteJson { input, .. }) => input.partition_spec(), } } + + pub fn approximate_size_bytes(&self) -> Option { + match self { + #[cfg(feature = "python")] + Self::InMemoryScan(InMemoryScan { in_memory_info, .. }) => { + Some(in_memory_info.size_bytes) + } + Self::TabularScan(TabularScan { scan_tasks, .. }) => scan_tasks + .iter() + .map(|scan_task| scan_task.size_bytes()) + .sum::>(), + // Propagate child approximation for operations that don't affect cardinality. + Self::Coalesce(Coalesce { input, .. }) + | Self::FanoutByHash(FanoutByHash { input, .. }) + | Self::FanoutByRange(FanoutByRange { input, .. }) + | Self::FanoutRandom(FanoutRandom { input, .. }) + | Self::Flatten(Flatten { input, .. }) + | Self::ReduceMerge(ReduceMerge { input, .. }) + | Self::Sort(Sort { input, .. }) + | Self::Split(Split { input, .. }) => input.approximate_size_bytes(), + // TODO(Clark): Add size bytes estimates for other operators. + _ => None, + } + } } /// A work scheduler for physical plans. @@ -582,7 +608,7 @@ impl PhysicalPlan { .call1((upstream_input_iter, upstream_other_iter))?; Ok(py_iter.into()) } - PhysicalPlan::Join(Join { + PhysicalPlan::HashJoin(HashJoin { left, right, left_on, @@ -602,13 +628,44 @@ impl PhysicalPlan { .collect(); let py_iter = py .import(pyo3::intern!(py, "daft.execution.rust_physical_plan_shim"))? - .getattr(pyo3::intern!(py, "join"))? + .getattr(pyo3::intern!(py, "hash_join"))? + .call1(( + upstream_left_iter, + upstream_right_iter, + left_on_pyexprs, + right_on_pyexprs, + *join_type, + ))?; + Ok(py_iter.into()) + } + PhysicalPlan::BroadcastJoin(BroadcastJoin { + left, + right, + left_on, + right_on, + join_type, + is_swapped, + }) => { + let upstream_left_iter = left.to_partition_tasks(py, psets, is_ray_runner)?; + let upstream_right_iter = right.to_partition_tasks(py, psets, is_ray_runner)?; + let left_on_pyexprs: Vec = left_on + .iter() + .map(|expr| PyExpr::from(expr.clone())) + .collect(); + let right_on_pyexprs: Vec = right_on + .iter() + .map(|expr| PyExpr::from(expr.clone())) + .collect(); + let py_iter = py + .import(pyo3::intern!(py, "daft.execution.rust_physical_plan_shim"))? + .getattr(pyo3::intern!(py, "broadcast_join"))? .call1(( upstream_left_iter, upstream_right_iter, left_on_pyexprs, right_on_pyexprs, *join_type, + *is_swapped, ))?; Ok(py_iter.into()) } diff --git a/src/daft-plan/src/planner.rs b/src/daft-plan/src/planner.rs index 9e23a1c627..9390efc048 100644 --- a/src/daft-plan/src/planner.rs +++ b/src/daft-plan/src/planner.rs @@ -487,6 +487,39 @@ pub fn plan(logical_plan: &LogicalPlan, cfg: Arc) -> DaftResult { let mut left_physical = plan(left, cfg.clone())?; let mut right_physical = plan(right, cfg.clone())?; + + // If either the left or right side of the join are very small tables, perform a broadcast join with the + // entire smaller table broadcast to each of the partitions of the larger table. + + // Ensure that the left side of the join is the smaller side. + let mut do_swap = false; + let smaller_size_bytes = match ( + left_physical.approximate_size_bytes(), + right_physical.approximate_size_bytes(), + ) { + (Some(left_size_bytes), Some(right_size_bytes)) => { + if right_size_bytes < left_size_bytes { + do_swap = true; + Some(right_size_bytes) + } else { + Some(left_size_bytes) + } + } + (Some(left_size_bytes), None) => Some(left_size_bytes), + (None, Some(right_size_bytes)) => { + do_swap = true; + Some(right_size_bytes) + } + (None, None) => None, + }; + // If smaller table is under broadcast size threshold, use broadcast join. + if let Some(smaller_size_bytes) = smaller_size_bytes && smaller_size_bytes <= cfg.broadcast_join_size_bytes_threshold { + if do_swap { + // These will get swapped back when doing the actual local joins. + (left_physical, right_physical) = (right_physical, left_physical); + } + return Ok(PhysicalPlan::BroadcastJoin(BroadcastJoin::new(left_physical.into(), right_physical.into(), left_on.clone(), right_on.clone(), *join_type, do_swap))); + } let left_pspec = left_physical.partition_spec(); let right_pspec = right_physical.partition_spec(); let num_partitions = max(left_pspec.num_partitions, right_pspec.num_partitions); @@ -520,7 +553,7 @@ pub fn plan(logical_plan: &LogicalPlan, cfg: Arc) -> DaftResult Self { Self { source_schema, cache_key, cache_entry, num_partitions, + size_bytes, } } } diff --git a/src/daft-table/src/ops/joins/hash_join.rs b/src/daft-table/src/ops/joins/hash_join.rs index 68c3b7dfaf..56b80aedb9 100644 --- a/src/daft-table/src/ops/joins/hash_join.rs +++ b/src/daft-table/src/ops/joins/hash_join.rs @@ -43,6 +43,7 @@ pub(super) fn hash_inner_join(left: &Table, right: &Table) -> DaftResult<(Series )); } + // TODO(Clark): Build the probe table on the smaller table, rather than always building it on the left table. let probe_table = left.to_probe_hash_table()?; let r_hashes = right.hash_rows()?; diff --git a/tests/dataframe/test_joins.py b/tests/dataframe/test_joins.py index 30bbd16a45..a376aacf3f 100644 --- a/tests/dataframe/test_joins.py +++ b/tests/dataframe/test_joins.py @@ -3,13 +3,25 @@ import pyarrow as pa import pytest +import daft from daft.datatype import DataType from daft.errors import ExpressionTypeError from tests.utils import sort_arrow_table +@pytest.fixture(params=[False, True]) +def broadcast_join_enabled(request): + # Toggles between default broadcast join threshold (10 MiB), and a threshold of 0, which disables broadcast joins. + broadcast_threshold = 10 * 1024 * 1024 if request.param else False + old_context = daft.context.pop_context() + try: + yield daft.context.set_config(broadcast_join_size_bytes_threshold=broadcast_threshold) + finally: + daft.context.set_context(old_context) + + @pytest.mark.parametrize("n_partitions", [1, 2, 4]) -def test_multicol_joins(make_df, n_partitions: int): +def test_multicol_joins(broadcast_join_enabled, make_df, n_partitions: int): df = make_df( { "A": [1, 2, 3], @@ -32,7 +44,7 @@ def test_multicol_joins(make_df, n_partitions: int): @pytest.mark.parametrize("n_partitions", [1, 2, 4]) -def test_limit_after_join(make_df, n_partitions: int): +def test_limit_after_join(broadcast_join_enabled, make_df, n_partitions: int): data = { "A": [1, 2, 3], } @@ -59,7 +71,7 @@ def test_limit_after_join(make_df, n_partitions: int): @pytest.mark.parametrize("repartition_nparts", [1, 2, 4]) -def test_inner_join(make_df, repartition_nparts): +def test_inner_join(broadcast_join_enabled, make_df, repartition_nparts): daft_df = make_df( { "id": [1, None, 3], @@ -87,7 +99,7 @@ def test_inner_join(make_df, repartition_nparts): @pytest.mark.parametrize("repartition_nparts", [1, 2, 4]) -def test_inner_join_multikey(make_df, repartition_nparts): +def test_inner_join_multikey(broadcast_join_enabled, make_df, repartition_nparts): daft_df = make_df( { "id": [1, None, None], @@ -118,7 +130,7 @@ def test_inner_join_multikey(make_df, repartition_nparts): @pytest.mark.parametrize("repartition_nparts", [1, 2, 4]) -def test_inner_join_all_null(make_df, repartition_nparts): +def test_inner_join_all_null(broadcast_join_enabled, make_df, repartition_nparts): daft_df = make_df( { "id": [None, None, None], @@ -145,7 +157,7 @@ def test_inner_join_all_null(make_df, repartition_nparts): ) -def test_inner_join_null_type_column(make_df): +def test_inner_join_null_type_column(broadcast_join_enabled, make_df): daft_df = make_df( { "id": [None, None, None], From 404d2f86fc7c9b600fb177c880dd96678061c43e Mon Sep 17 00:00:00 2001 From: clarkzinzow Date: Fri, 8 Dec 2023 12:50:06 -0800 Subject: [PATCH 2/4] PR feedback --- daft/execution/execution_step.py | 2 + daft/execution/physical_plan.py | 118 +++++++++--------- daft/execution/rust_physical_plan_shim.py | 8 +- src/common/daft-config/src/lib.rs | 4 +- .../src/physical_ops/broadcast_join.rs | 12 +- src/daft-plan/src/physical_plan.rs | 8 +- src/daft-plan/src/planner.rs | 17 +-- tests/dataframe/test_joins.py | 2 +- 8 files changed, 85 insertions(+), 86 deletions(-) diff --git a/daft/execution/execution_step.py b/daft/execution/execution_step.py index f8eabe3e9e..cf33efc2a0 100644 --- a/daft/execution/execution_step.py +++ b/daft/execution/execution_step.py @@ -635,8 +635,10 @@ def run(self, inputs: list[MicroPartition]) -> list[MicroPartition]: return self._join(inputs) def _join(self, inputs: list[MicroPartition]) -> list[MicroPartition]: + # All inputs except for the last are the left side of the join, in order to support left-broadcasted joins. *lefts, right = inputs if len(lefts) > 1: + # NOTE: MicroPartition concats don't concatenate the underlying column arrays, since MicroPartitions are chunked. left = MicroPartition.concat(lefts) else: left = lefts[0] diff --git a/daft/execution/physical_plan.py b/daft/execution/physical_plan.py index e5862cee74..40868673ec 100644 --- a/daft/execution/physical_plan.py +++ b/daft/execution/physical_plan.py @@ -272,9 +272,65 @@ def hash_join( return +def _create_join_step( + broadcaster_parts: deque[SingleOutputPartitionTask[PartitionT]], + receiver_part: SingleOutputPartitionTask[PartitionT], + left_on: ExpressionsProjection, + right_on: ExpressionsProjection, + how: JoinType, + is_swapped: bool, +) -> PartitionTaskBuilder[PartitionT]: + # Calculate memory request for task. + broadcaster_size_bytes_ = 0 + broadcaster_partitions = [] + broadcaster_partition_metadatas = [] + null_count = 0 + for next_broadcaster in broadcaster_parts: + next_broadcaster_partition_metadata = next_broadcaster.partition_metadata() + if next_broadcaster_partition_metadata is None or next_broadcaster_partition_metadata.size_bytes is None: + null_count += 1 + else: + broadcaster_size_bytes_ += next_broadcaster_partition_metadata.size_bytes + broadcaster_partitions.append(next_broadcaster.partition()) + broadcaster_partition_metadatas.append(next_broadcaster_partition_metadata) + if null_count == len(broadcaster_parts): + broadcaster_size_bytes = None + elif null_count > 0: + # Impute null size estimates with mean of non-null estimates. + broadcaster_size_bytes = broadcaster_size_bytes_ + math.ceil( + null_count * broadcaster_size_bytes_ / (len(broadcaster_parts) - null_count) + ) + else: + broadcaster_size_bytes = broadcaster_size_bytes_ + receiver_size_bytes = receiver_part.partition_metadata().size_bytes + if broadcaster_size_bytes is None and receiver_size_bytes is None: + size_bytes = None + elif broadcaster_size_bytes is None and receiver_size_bytes is not None: + # Use 1.25x the receiver side as the memory request, assuming that receiver side is ~4x larger than the broadcaster side. + size_bytes = int(1.25 * receiver_size_bytes) + elif receiver_size_bytes is None and broadcaster_size_bytes is not None: + # Use 4x the broadcaster side as the memory request, assuming that receiver side is ~4x larger than the broadcaster side. + size_bytes = 4 * broadcaster_size_bytes + elif broadcaster_size_bytes is not None and receiver_size_bytes is not None: + size_bytes = broadcaster_size_bytes + receiver_size_bytes + + return PartitionTaskBuilder[PartitionT]( + inputs=broadcaster_partitions + [receiver_part.partition()], + partial_metadatas=list(broadcaster_partition_metadatas + [receiver_part.partition_metadata()]), + resource_request=ResourceRequest(memory_bytes=size_bytes), + ).add_instruction( + instruction=execution_step.Join( + left_on=left_on, + right_on=right_on, + how=how, + is_swapped=is_swapped, + ) + ) + + def broadcast_join( broadcaster_plan: InProgressPhysicalPlan[PartitionT], - reciever_plan: InProgressPhysicalPlan[PartitionT], + receiver_plan: InProgressPhysicalPlan[PartitionT], left_on: ExpressionsProjection, right_on: ExpressionsProjection, how: JoinType, @@ -282,57 +338,6 @@ def broadcast_join( ) -> InProgressPhysicalPlan[PartitionT]: """Broadcast join all partitions from the broadcaster child plan to each partition in the receiver child plan.""" - def _emit_join_step( - broadcaster_parts: deque[SingleOutputPartitionTask[PartitionT]], - receiver_part: SingleOutputPartitionTask[PartitionT], - ) -> PartitionTaskBuilder[PartitionT]: - # Calculate memory request for task. - broadcaster_size_bytes_ = 0 - broadcaster_partitions = [] - broadcaster_partition_metadatas = [] - null_count = 0 - for next_broadcaster in broadcaster_parts: - next_broadcaster_partition_metadata = next_broadcaster.partition_metadata() - if next_broadcaster_partition_metadata is None or next_broadcaster_partition_metadata.size_bytes is None: - null_count += 1 - else: - broadcaster_size_bytes_ += next_broadcaster_partition_metadata.size_bytes - broadcaster_partitions.append(next_broadcaster.partition()) - broadcaster_partition_metadatas.append(next_broadcaster_partition_metadata) - if null_count == len(broadcaster_parts): - broadcaster_size_bytes = None - elif null_count > 0: - # Impute null size estimates with mean of non-null estimates. - broadcaster_size_bytes = broadcaster_size_bytes_ + math.ceil( - null_count * broadcaster_size_bytes_ / (len(broadcaster_parts) - null_count) - ) - else: - broadcaster_size_bytes = broadcaster_size_bytes_ - receiver_size_bytes = receiver_part.partition_metadata().size_bytes - if broadcaster_size_bytes is None and receiver_size_bytes is None: - size_bytes = None - elif broadcaster_size_bytes is None and receiver_size_bytes is not None: - # Use 1.25x the receiver side as the memory request, assuming that receiver side is ~4x larger than the broadcaster side. - size_bytes = int(1.25 * receiver_size_bytes) - elif receiver_size_bytes is None and broadcaster_size_bytes is not None: - # Use 4x the broadcaster side as the memory request, assuming that receiver side is ~4x larger than the broadcaster side. - size_bytes = 4 * broadcaster_size_bytes - elif broadcaster_size_bytes is not None and receiver_size_bytes is not None: - size_bytes = broadcaster_size_bytes + receiver_size_bytes - - return PartitionTaskBuilder[PartitionT]( - inputs=broadcaster_partitions + [receiver_part.partition()], - partial_metadatas=list(broadcaster_partition_metadatas + [receiver_part.partition_metadata()]), - resource_request=ResourceRequest(memory_bytes=size_bytes), - ).add_instruction( - instruction=execution_step.Join( - left_on=left_on, - right_on=right_on, - how=how, - is_swapped=is_swapped, - ) - ) - # Materialize the steps from the broadcaster and receiver sources to get partitions. # As the receiver-side materializations complete, emit new steps to join each broadcaster and receiver partition. stage_id = next(stage_id_counter) @@ -364,9 +369,9 @@ def _emit_join_step( # Second, broadcast materialized partitions to receiver side of join, as it materializes. receiver_requests: deque[SingleOutputPartitionTask[PartitionT]] = deque() - receiver_parts: deque[SingleOutputPartitionTask[PartitionT]] = deque() while True: + receiver_parts: deque[SingleOutputPartitionTask[PartitionT]] = deque() # Moved completed partition tasks in the receiver side of the join to the materialized partition set. while receiver_requests and receiver_requests[0].done(): receiver_parts.append(receiver_requests.popleft()) @@ -374,14 +379,11 @@ def _emit_join_step( # Emit join steps for newly materialized partitions. # Broadcast all broadcaster partitions to each new receiver partition that was materialized on this dispatch loop. for receiver_part in receiver_parts: - yield _emit_join_step(broadcaster_parts, receiver_part) - - # We always emit a join step for every newly materialized receiver-side partition, so clear the materialized partition queue. - receiver_parts.clear() + yield _create_join_step(broadcaster_parts, receiver_part, left_on, right_on, how, is_swapped) # Execute single child step to pull in more input partitions. try: - step = next(reciever_plan) + step = next(receiver_plan) if isinstance(step, PartitionTaskBuilder): step = step.finalize_partition_task_single_output(stage_id=stage_id) receiver_requests.append(step) diff --git a/daft/execution/rust_physical_plan_shim.py b/daft/execution/rust_physical_plan_shim.py index d2193dd8bd..bb1016088a 100644 --- a/daft/execution/rust_physical_plan_shim.py +++ b/daft/execution/rust_physical_plan_shim.py @@ -207,8 +207,8 @@ def hash_join( def broadcast_join( - input: physical_plan.InProgressPhysicalPlan[PartitionT], - right: physical_plan.InProgressPhysicalPlan[PartitionT], + broadcaster: physical_plan.InProgressPhysicalPlan[PartitionT], + receiver: physical_plan.InProgressPhysicalPlan[PartitionT], left_on: list[PyExpr], right_on: list[PyExpr], join_type: JoinType, @@ -217,8 +217,8 @@ def broadcast_join( left_on_expr_proj = ExpressionsProjection([Expression._from_pyexpr(expr) for expr in left_on]) right_on_expr_proj = ExpressionsProjection([Expression._from_pyexpr(expr) for expr in right_on]) return physical_plan.broadcast_join( - broadcaster_plan=input, - reciever_plan=right, + broadcaster_plan=broadcaster, + receiver_plan=receiver, left_on=left_on_expr_proj, right_on=right_on_expr_proj, how=join_type, diff --git a/src/common/daft-config/src/lib.rs b/src/common/daft-config/src/lib.rs index 0a2a0b1429..7134f65893 100644 --- a/src/common/daft-config/src/lib.rs +++ b/src/common/daft-config/src/lib.rs @@ -12,9 +12,7 @@ impl Default for DaftConfig { DaftConfig { merge_scan_tasks_min_size_bytes: 64 * 1024 * 1024, // 64 MiB merge_scan_tasks_max_size_bytes: 512 * 1024 * 1024, // 512 MiB - broadcast_join_size_bytes_threshold: std::env::var("DAFT_BROADCAST_JOIN_THRESHOLD") - .map(|s| s.parse::().unwrap()) - .unwrap_or(10 * 1024 * 1024), // 10 MiB + broadcast_join_size_bytes_threshold: 10 * 1024 * 1024, // 10 MiB } } } diff --git a/src/daft-plan/src/physical_ops/broadcast_join.rs b/src/daft-plan/src/physical_ops/broadcast_join.rs index 06ab9fd5d0..4b555439d6 100644 --- a/src/daft-plan/src/physical_ops/broadcast_join.rs +++ b/src/daft-plan/src/physical_ops/broadcast_join.rs @@ -8,8 +8,8 @@ use serde::{Deserialize, Serialize}; #[derive(Clone, Debug, Serialize, Deserialize)] pub struct BroadcastJoin { // Upstream node. - pub left: Arc, - pub right: Arc, + pub broadcaster: Arc, + pub receiver: Arc, pub left_on: Vec, pub right_on: Vec, pub join_type: JoinType, @@ -18,16 +18,16 @@ pub struct BroadcastJoin { impl BroadcastJoin { pub(crate) fn new( - left: Arc, - right: Arc, + broadcaster: Arc, + receiver: Arc, left_on: Vec, right_on: Vec, join_type: JoinType, is_swapped: bool, ) -> Self { Self { - left, - right, + broadcaster, + receiver, left_on, right_on, join_type, diff --git a/src/daft-plan/src/physical_plan.rs b/src/daft-plan/src/physical_plan.rs index 11d2678600..5f6f6e8631 100644 --- a/src/daft-plan/src/physical_plan.rs +++ b/src/daft-plan/src/physical_plan.rs @@ -160,7 +160,9 @@ impl PhysicalPlan { .into(), } } - Self::BroadcastJoin(BroadcastJoin { right, .. }) => right.partition_spec(), + Self::BroadcastJoin(BroadcastJoin { + receiver: right, .. + }) => right.partition_spec(), Self::TabularWriteParquet(TabularWriteParquet { input, .. }) => input.partition_spec(), Self::TabularWriteCsv(TabularWriteCsv { input, .. }) => input.partition_spec(), Self::TabularWriteJson(TabularWriteJson { input, .. }) => input.partition_spec(), @@ -639,8 +641,8 @@ impl PhysicalPlan { Ok(py_iter.into()) } PhysicalPlan::BroadcastJoin(BroadcastJoin { - left, - right, + broadcaster: left, + receiver: right, left_on, right_on, join_type, diff --git a/src/daft-plan/src/planner.rs b/src/daft-plan/src/planner.rs index 9390efc048..173ce9f795 100644 --- a/src/daft-plan/src/planner.rs +++ b/src/daft-plan/src/planner.rs @@ -492,25 +492,20 @@ pub fn plan(logical_plan: &LogicalPlan, cfg: Arc) -> DaftResult { if right_size_bytes < left_size_bytes { - do_swap = true; - Some(right_size_bytes) + (Some(right_size_bytes), true) } else { - Some(left_size_bytes) + (Some(left_size_bytes), false) } } - (Some(left_size_bytes), None) => Some(left_size_bytes), - (None, Some(right_size_bytes)) => { - do_swap = true; - Some(right_size_bytes) - } - (None, None) => None, + (Some(left_size_bytes), None) => (Some(left_size_bytes), false), + (None, Some(right_size_bytes)) => (Some(right_size_bytes), true), + (None, None) => (None, false), }; // If smaller table is under broadcast size threshold, use broadcast join. if let Some(smaller_size_bytes) = smaller_size_bytes && smaller_size_bytes <= cfg.broadcast_join_size_bytes_threshold { diff --git a/tests/dataframe/test_joins.py b/tests/dataframe/test_joins.py index a376aacf3f..a538bb5c53 100644 --- a/tests/dataframe/test_joins.py +++ b/tests/dataframe/test_joins.py @@ -12,7 +12,7 @@ @pytest.fixture(params=[False, True]) def broadcast_join_enabled(request): # Toggles between default broadcast join threshold (10 MiB), and a threshold of 0, which disables broadcast joins. - broadcast_threshold = 10 * 1024 * 1024 if request.param else False + broadcast_threshold = 10 * 1024 * 1024 if request.param else 0 old_context = daft.context.pop_context() try: yield daft.context.set_config(broadcast_join_size_bytes_threshold=broadcast_threshold) From 68160ab971d44468d036b590f7a665142f1695ed Mon Sep 17 00:00:00 2001 From: clarkzinzow Date: Mon, 11 Dec 2023 10:27:28 -0800 Subject: [PATCH 3/4] Don't use broadcast join if larger table is already partitioned on join key. --- src/daft-plan/src/planner.rs | 43 ++++++++++++++++++++++-------------- 1 file changed, 26 insertions(+), 17 deletions(-) diff --git a/src/daft-plan/src/planner.rs b/src/daft-plan/src/planner.rs index 173ce9f795..b8d1018b9e 100644 --- a/src/daft-plan/src/planner.rs +++ b/src/daft-plan/src/planner.rs @@ -488,6 +488,23 @@ pub fn plan(logical_plan: &LogicalPlan, cfg: Arc) -> DaftResult) -> DaftResult (Some(right_size_bytes), true), (None, None) => (None, false), }; - // If smaller table is under broadcast size threshold, use broadcast join. - if let Some(smaller_size_bytes) = smaller_size_bytes && smaller_size_bytes <= cfg.broadcast_join_size_bytes_threshold { + let is_larger_partitioned = if do_swap { + is_right_partitioned + } else { + is_left_partitioned + }; + // If larger table is not already partitioned on the join key AND the smaller table is under broadcast size threshold, use broadcast join. + if !is_larger_partitioned && let Some(smaller_size_bytes) = smaller_size_bytes && smaller_size_bytes <= cfg.broadcast_join_size_bytes_threshold { if do_swap { // These will get swapped back when doing the actual local joins. (left_physical, right_physical) = (right_physical, left_physical); } return Ok(PhysicalPlan::BroadcastJoin(BroadcastJoin::new(left_physical.into(), right_physical.into(), left_on.clone(), right_on.clone(), *join_type, do_swap))); } - let left_pspec = left_physical.partition_spec(); - let right_pspec = right_physical.partition_spec(); - let num_partitions = max(left_pspec.num_partitions, right_pspec.num_partitions); - let new_left_pspec = Arc::new(PartitionSpec::new_internal( - PartitionScheme::Hash, - num_partitions, - Some(left_on.clone()), - )); - let new_right_pspec = Arc::new(PartitionSpec::new_internal( - PartitionScheme::Hash, - num_partitions, - Some(right_on.clone()), - )); if (num_partitions > 1 || left_pspec.num_partitions != num_partitions) - && left_pspec != new_left_pspec + && !is_left_partitioned { let split_op = PhysicalPlan::FanoutByHash(FanoutByHash::new( left_physical.into(), @@ -539,7 +548,7 @@ pub fn plan(logical_plan: &LogicalPlan, cfg: Arc) -> DaftResult 1 || right_pspec.num_partitions != num_partitions) - && right_pspec != new_right_pspec + && !is_right_partitioned { let split_op = PhysicalPlan::FanoutByHash(FanoutByHash::new( right_physical.into(), From d02e5dbd4018e92d8c4eaaeb9a2c29d266d44625 Mon Sep 17 00:00:00 2001 From: clarkzinzow Date: Mon, 11 Dec 2023 10:28:17 -0800 Subject: [PATCH 4/4] Expand size approximation to more ops. --- src/daft-plan/src/physical_plan.rs | 42 ++++++++++++++++++++++++++++-- 1 file changed, 40 insertions(+), 2 deletions(-) diff --git a/src/daft-plan/src/physical_plan.rs b/src/daft-plan/src/physical_plan.rs index 5f6f6e8631..68112897ee 100644 --- a/src/daft-plan/src/physical_plan.rs +++ b/src/daft-plan/src/physical_plan.rs @@ -179,6 +179,14 @@ impl PhysicalPlan { .iter() .map(|scan_task| scan_task.size_bytes()) .sum::>(), + // Assume no row/column pruning in cardinality-affecting operations. + // TODO(Clark): Estimate row/column pruning to get a better size approximation. + Self::Filter(Filter { input, .. }) + | Self::Limit(Limit { input, .. }) + | Self::Project(Project { input, .. }) => input.approximate_size_bytes(), + // Assume ~the same size in bytes for explodes. + // TODO(Clark): Improve this estimate. + Self::Explode(Explode { input, .. }) => input.approximate_size_bytes(), // Propagate child approximation for operations that don't affect cardinality. Self::Coalesce(Coalesce { input, .. }) | Self::FanoutByHash(FanoutByHash { input, .. }) @@ -188,8 +196,38 @@ impl PhysicalPlan { | Self::ReduceMerge(ReduceMerge { input, .. }) | Self::Sort(Sort { input, .. }) | Self::Split(Split { input, .. }) => input.approximate_size_bytes(), - // TODO(Clark): Add size bytes estimates for other operators. - _ => None, + Self::Concat(Concat { input, other }) => { + input.approximate_size_bytes().and_then(|input_size| { + other + .approximate_size_bytes() + .map(|other_size| input_size + other_size) + }) + } + // Assume a simple sum of the sizes of both sides of the join for the post-join size. + // TODO(Clark): This will double-count join key columns, we should ensure that these are only counted once. + Self::BroadcastJoin(BroadcastJoin { + broadcaster: left, + receiver: right, + .. + }) + | Self::HashJoin(HashJoin { left, right, .. }) => { + left.approximate_size_bytes().and_then(|left_size| { + right + .approximate_size_bytes() + .map(|right_size| left_size + right_size) + }) + } + // TODO(Clark): Approximate post-aggregation sizes via grouping estimates + aggregation type. + Self::Aggregate(_) => None, + // No size approximation support for legacy I/O. + Self::TabularScanParquet(_) | Self::TabularScanCsv(_) | Self::TabularScanJson(_) => { + None + } + // Post-write DataFrame will contain paths to files that were written. + // TODO(Clark): Estimate output size via root directory and estimates for # of partitions given partitioning column. + Self::TabularWriteParquet(_) | Self::TabularWriteCsv(_) | Self::TabularWriteJson(_) => { + None + } } } }