diff --git a/daft/context.py b/daft/context.py index caf74ef4d6..0b493b4484 100644 --- a/daft/context.py +++ b/daft/context.py @@ -122,6 +122,11 @@ def runner(self) -> Runner: with self._lock: return self._get_runner() + def shuffle_service_factory(self): + from daft.runners.ray_runner import RayShuffleServiceFactory + + return RayShuffleServiceFactory() + @property def daft_execution_config(self) -> PyDaftExecutionConfig: with self._lock: diff --git a/daft/execution/physical_plan.py b/daft/execution/physical_plan.py index 91be85e02a..3b0bff13a1 100644 --- a/daft/execution/physical_plan.py +++ b/daft/execution/physical_plan.py @@ -57,8 +57,9 @@ from pyiceberg.schema import Schema as IcebergSchema from pyiceberg.table import TableProperties as IcebergTableProperties - from daft.daft import FileFormat, IOConfig, JoinType + from daft.daft import FileFormat, IOConfig, JoinType, PyExpr from daft.logical.schema import Schema + from daft.runners.partitioning import PartialPartitionMetadata # A PhysicalPlan that is still being built - may yield both PartitionTaskBuilders and PartitionTasks. @@ -1615,6 +1616,194 @@ def fanout_random(child_plan: InProgressPhysicalPlan[PartitionT], num_partitions seed += 1 +def fully_materializing_push_exchange_op( + child_plan: InProgressPhysicalPlan[PartitionT], partition_by: list[PyExpr], num_partitions: int +) -> InProgressPhysicalPlan[PartitionT]: + from daft.expressions import Expression + + # Step 1: Naively materialize all child partitions + stage_id_children = next(stage_id_counter) + materialized_partitions: list[SingleOutputPartitionTask] = [] + for step in child_plan: + if isinstance(step, PartitionTaskBuilder): + task = step.finalize_partition_task_single_output(stage_id=stage_id_children) + materialized_partitions.append(task) + yield task + elif isinstance(step, PartitionTask): + yield step + elif step is None: + yield None + else: + yield step + + # Step 2: Wait for all partitions to be done + while any(not p.done() for p in materialized_partitions): + yield None + + with get_context().shuffle_service_factory().push_based_shuffle_service_context( + num_partitions, partition_by=ExpressionsProjection([Expression._from_pyexpr(e) for e in partition_by]) + ) as shuffle_service: + results = shuffle_service.run([p.partition() for p in materialized_partitions]) + + for reduced_data in results: + reduce_task = PartitionTaskBuilder( + inputs=[reduced_data], + partial_metadatas=None, + resource_request=ResourceRequest(), + ) + yield reduce_task + + +def fully_materializing_exchange_op( + child_plan: InProgressPhysicalPlan[PartitionT], partition_by: list[PyExpr], num_partitions: int +) -> InProgressPhysicalPlan[PartitionT]: + from daft.expressions import Expression + + # Step 1: Naively materialize all child partitions + stage_id_children = next(stage_id_counter) + materialized_partitions: list[SingleOutputPartitionTask] = [] + for step in child_plan: + if isinstance(step, PartitionTaskBuilder): + task = step.finalize_partition_task_single_output(stage_id=stage_id_children) + materialized_partitions.append(task) + yield task + elif isinstance(step, PartitionTask): + yield step + elif step is None: + yield None + else: + yield step + + # Step 2: Wait for all partitions to be done + while any(not p.done() for p in materialized_partitions): + yield None + + # Step 3: Yield the map tasks + stage_id_map_tasks = next(stage_id_counter) + materialized_map_partitions: list[MultiOutputPartitionTask] = [] + while materialized_partitions: + materialized_child_partition = 
materialized_partitions.pop(0) + map_task = ( + PartitionTaskBuilder( + inputs=[materialized_child_partition.partition()], + partial_metadatas=materialized_child_partition.partial_metadatas, + resource_request=ResourceRequest(), + ) + .add_instruction( + execution_step.FanoutHash( + _num_outputs=num_partitions, + partition_by=ExpressionsProjection([Expression._from_pyexpr(expr) for expr in partition_by]), + ), + ResourceRequest(), + ) + .finalize_partition_task_multi_output(stage_id=stage_id_map_tasks) + ) + materialized_map_partitions.append(map_task) + yield map_task + + # Step 4: Wait on all the map tasks to complete + while any(not p.done() for p in materialized_map_partitions): + yield None + + # Step 5: "Transpose the results" and run reduce tasks + transposed_results: list[list[tuple[PartitionT, PartialPartitionMetadata]]] = [[] for _ in range(num_partitions)] + for map_task in materialized_map_partitions: + partitions = map_task.partitions() + partition_metadatas = map_task.partial_metadatas + for i, (partition, meta) in enumerate(zip(partitions, partition_metadatas)): + transposed_results[i].append((partition, meta)) + + for i, partitions in enumerate(transposed_results): + reduce_task = PartitionTaskBuilder( + inputs=[p for p, _ in partitions], + partial_metadatas=[m for _, m in partitions], + resource_request=ResourceRequest(), + ).add_instruction( + instruction=execution_step.ReduceMerge(), + ) + yield reduce_task + + +# This was the complicated one... +# +# def fully_materializing_exchange_op( +# child_plan: InProgressPhysicalPlan[PartitionT], partition_by: list[PyExpr], num_partitions: int +# ) -> InProgressPhysicalPlan[PartitionT]: +# from daft.execution.physical_plan_shuffles import HashPartitionRequest + +# prior_stage_id = next(stage_id_counter) + +# # Yield children stuff and avoid creating the shuffle service until we start running +# # tasks in the stage directly prior to this +# child_task = None +# while child_task is None: +# step = next(child_plan) +# if step is None: +# yield None +# continue +# elif isinstance(step, PartitionTask): +# yield step +# continue +# else: +# assert isinstance(step, PartitionTaskBuilder) +# child_task = step.finalize_partition_task_single_output(prior_stage_id) +# break + +# materializations: deque[SingleOutputPartitionTask] = deque() +# materializations.append(child_task) +# yield child_task + +# MAX_NUM_CHILD_INFLIGHT_TASKS_BEFORE_INGESTION = 128 + +# # Create the shuffle service and start materializing children and sending data to the service +# with get_context().shuffle_service_factory().fully_materializing_shuffle_service_context( +# num_partitions, +# [c.name() for c in partition_by], # TODO: Assume no-op here for now, YOLO! +# ) as shuffle_service: +# child_plan_exhausted = False +# while not child_plan_exhausted or len(materializations) > 0: +# # Ingest as many materialized results as possible +# materialized: list[SingleOutputPartitionTask] = [] +# while len(materializations) > 0 and materializations[0].done(): +# materialized.append(materializations.popleft()) +# results = [done_task.result() for done_task in materialized] +# if len(results) > 0: +# _ingest_results = shuffle_service.ingest([r.partition for r in results]) + +# # Keep pulling steps from children until either: +# # 1. We hit `MAX_NUM_CHILD_INFLIGHT_TASKS_BEFORE_INGESTION` and want to chill a bit to ingest the data +# # 2. 
We exhaust the child plan +# while len(materializations) < MAX_NUM_CHILD_INFLIGHT_TASKS_BEFORE_INGESTION: +# try: +# step = next(child_plan) +# except StopIteration: +# child_plan_exhausted = True +# break +# if step is None: +# yield step +# elif isinstance(step, PartitionTask): +# yield step +# else: +# assert isinstance(step, PartitionTaskBuilder) +# child_task = step.finalize_partition_task_single_output(prior_stage_id) +# materializations.append(child_task) +# yield child_task + +# # Read from the shuffle service in chunks of 1GB +# MAX_SIZE_BYTES = 1024 * 1024 * 1024 +# partition_requests = [ +# list(shuffle_service.read(HashPartitionRequest(type_="hash", bucket=i), MAX_SIZE_BYTES)) +# for i in range(num_partitions) +# ] + +# for partition_chunks in partition_requests: +# yield PartitionTaskBuilder[PartitionT]( +# inputs=partition_chunks, +# partial_metadatas=None, +# resource_request=ResourceRequest(), +# ) + + def _best_effort_next_step( stage_id: int, child_plan: InProgressPhysicalPlan[PartitionT] ) -> tuple[PartitionTask[PartitionT] | None, bool]: diff --git a/daft/execution/physical_plan_shuffles.py b/daft/execution/physical_plan_shuffles.py new file mode 100644 index 0000000000..0d57f7709e --- /dev/null +++ b/daft/execution/physical_plan_shuffles.py @@ -0,0 +1,119 @@ +from __future__ import annotations + +import dataclasses +from typing import Any, Iterator, Literal, Protocol, TypeVar + +ShuffleData = TypeVar("ShuffleData") +IngestResult = TypeVar("IngestResult") + + +@dataclasses.dataclass(frozen=True) +class PartitioningSpec: + type_: Literal["hash"] | Literal["range"] + + def to_hash_pspec(self) -> HashPartitioningSpec: + assert self.type_ == "hash" and isinstance(self, HashPartitioningSpec) + return self + + def to_range_pspec(self) -> RangePartitioningSpec: + assert self.type_ == "range" and isinstance(self, RangePartitioningSpec) + return self + + +@dataclasses.dataclass(frozen=True) +class HashPartitioningSpec(PartitioningSpec): + num_partitions: int + columns: list[str] + + +@dataclasses.dataclass(frozen=True) +class RangePartitioningSpec(PartitioningSpec): + boundaries: list[Any] + columns: list[str] + + +@dataclasses.dataclass(frozen=True) +class PartitionRequest: + type_: Literal["hash"] | Literal["range"] + + def to_hash_request(self) -> HashPartitionRequest: + assert self.type_ == "hash" and isinstance(self, HashPartitionRequest) + return self + + def to_range_request(self) -> RangePartitionRequest: + assert self.type_ == "range" and isinstance(self, RangePartitionRequest) + return self + + +@dataclasses.dataclass(frozen=True) +class HashPartitionRequest(PartitionRequest): + bucket: int + + +@dataclasses.dataclass(frozen=True) +class RangePartitionRequest(PartitionRequest): + start_end_values: list[tuple[Any, Any]] + + +class ShuffleServiceInterface(Protocol[ShuffleData, IngestResult]): + """An interface to a ShuffleService + + The job of a shuffle service is to `.ingest` results from the previous stage, perform partitioning on the data, + and then expose a `.read` to consumers of the "shuffled" data. + + NOTE: `.read` should throw an error before the ShuffleService is informed of the target partitioning. This + is because the ShuffleService needs to know how to partition the data before it can emit results. + + See BigQuery/Dremel video from CMU: https://www.youtube.com/watch?v=JxeITDS-xh0&ab_channel=CMUDatabaseGroup + """ + + def teardown(self) -> None: ... 
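A minimal sketch of how a consumer is expected to drive this interface, assuming a hash-partitioned service (`upstream_chunks`, `num_buckets`, and `max_size_bytes` are placeholder names, not part of the interface):

    service.ingest(upstream_chunks)            # hand over ShuffleData from the previous stage
    service.set_input_stage_completed()        # signal that no more input is coming
    for bucket in range(num_buckets):
        for chunk in service.read(HashPartitionRequest(type_="hash", bucket=bucket), max_size_bytes):
            ...                                # feed each ShuffleData chunk to the next stage
    service.teardown()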
+ + ### + # INGESTION: + # These endpoints allow the ShuffleService to ingest data from the previous stage of the query + ### + + def ingest(self, data: Iterator[ShuffleData]) -> list[IngestResult]: + """Receive some data. + + NOTE: This will throw an error if called after `.close_ingest` has been called. + """ + ... + + def set_input_stage_completed(self) -> None: + """Inform the ShuffleService that all data from the previous stage has been ingested""" + ... + + def is_input_stage_completed(self) -> bool: + """Query whether or not the previous stage has completed ingestion""" + ... + + ### + # READ: + # These endpoints allow clients to request data from the ShuffleService + ### + + def read(self, request: PartitionRequest, max_size_bytes: int) -> Iterator[ShuffleData]: + """Retrieves ShuffleData from the shuffle service for the specified partition. + + This returns an iterator of ShuffleData + + When all data is guaranteed to be exhausted for the given request, the iterator will raise + a StopIteration. + """ + ... + + # TODO: Dynamic Partitioning + # + # We could have the ShuffleService expose running statistics (as data is being collected) + # so that the coordinator can dynamically decide on an appropriate output partitioning scheme + # before it attempts to request for output ShuffleData + # + # def get_current_statistics(self) -> ShuffleStatistics: + # """Retrieves the current statistics from the ShuffleService's currently ingested data""" + # ... + # + # def set_output_partitioning(self, spec: PartitioningSpec) -> None: + # """Sets the intended output partitioning scheme that should be emitted""" + # ... diff --git a/daft/runners/pyrunner.py b/daft/runners/pyrunner.py index 31c56c3ad4..20db206bf9 100644 --- a/daft/runners/pyrunner.py +++ b/daft/runners/pyrunner.py @@ -361,6 +361,12 @@ def actor_pool_context( self._actor_pools[actor_pool_id].teardown() del self._actor_pools[actor_pool_id] + @contextlib.contextmanager + def shuffle_service_context( + self, + ) -> Iterator[str]: + raise NotImplementedError("shuffle_service_context not yet implemented in PyRunner") + def _physical_plan_to_partitions( self, execution_id: str, diff --git a/daft/runners/ray_runner.py b/daft/runners/ray_runner.py index d29a15c9f2..bebff472fe 100644 --- a/daft/runners/ray_runner.py +++ b/daft/runners/ray_runner.py @@ -13,11 +13,13 @@ # import times. If this changes, we first need to make the daft.lazy_import.LazyImport class # serializable before importing pa from daft.dependencies. 
import pyarrow as pa # noqa: TID253 +import ray.experimental # noqa: TID253 from daft.arrow_utils import ensure_array from daft.context import execution_config_ctx, get_context from daft.daft import PyTable as _PyTable from daft.dependencies import np +from daft.expressions import ExpressionsProjection from daft.runners.progress_bar import ProgressBar from daft.series import Series, item_to_series from daft.table import Table @@ -51,7 +53,6 @@ SingleOutputPartitionTask, StatefulUDFProject, ) -from daft.expressions import ExpressionsProjection from daft.filesystem import glob_path_with_stats from daft.runners import runner_io from daft.runners.partitioning import ( @@ -1252,3 +1253,404 @@ def from_metadata_list(cls, meta: list[PartitionMetadata]) -> PartitionMetadataA accessor = cls(ref) accessor._metadatas = meta return accessor + + +### +# ShuffleService +### + + +from daft.execution.physical_plan_shuffles import ( + HashPartitionRequest, + PartitioningSpec, + PartitionRequest, + ShuffleServiceInterface, +) + +ray.ObjectRef = ray.ObjectRef + + +@ray.remote +class ShuffleServiceActor: + """An actor that is part of the ShuffleService. This is meant to be spun up one-per-node""" + + def __init__(self, partition_spec: PartitioningSpec): + assert partition_spec.type_ == "hash", "Only hash partitioning is currently supported" + self._output_partitioning = partition_spec + hash_spec = self._output_partitioning.to_hash_pspec() + + # TODO: These can be made much more sophisticated, performing things such as disk/remote spilling + # as necessary. + self._unpartitioned_data_buffer: list[MicroPartition] = [] + self._partitioned_data_buffer: dict[PartitionRequest, list[MicroPartition]] = { + HashPartitionRequest(type_="hash", bucket=i): [] for i in range(hash_spec.num_partitions) + } + + @ray.method(num_returns=2) + def read(self, request: PartitionRequest, chunk_size_bytes: int) -> tuple[MicroPartition | None, int]: + """Retrieves ShuffleData from the shuffle service for the specified partition""" + if len(self._unpartitioned_data_buffer) > 0: + self._run_partitioning() + + assert request in self._partitioned_data_buffer, f"PartitionRequest must be in the buffer: {request}" + + buffer = self._partitioned_data_buffer[request] + if buffer: + result = buffer.pop(0) + + # Perform merging of result with the next results in the buffer + result_size: int = result.size_bytes() # type: ignore + assert ( + result_size is not None + ), "We really need the size here, but looks like lazy MPs without stats don't give us the size" + + while result_size < chunk_size_bytes and buffer: + next_partition = buffer[0] + next_partition_size = next_partition.size_bytes() + + assert ( + next_partition_size is not None + ), "We really need the size here, but looks like lazy MPs without stats don't give us the size" + + if result_size + next_partition_size <= chunk_size_bytes: + result = MicroPartition.concat([result, buffer.pop(0)]) + result_size = result.size_bytes() # type: ignore + else: + break + + # Split the result if it is too large (best-effort) + if result_size > chunk_size_bytes: + rows_per_byte = len(result) / result_size + target_rows = int(chunk_size_bytes * rows_per_byte) + result, remaining = result.slice(0, target_rows), result.slice(target_rows, len(result)) + buffer.insert(0, remaining) + + return (result, result.size_bytes()) # type: ignore + else: + return (None, 0) + + def ingest(self, data: list[MicroPartition]) -> None: + """Ingest data into the shuffle service""" + 
self._unpartitioned_data_buffer.extend(data) + return None + + def _run_partitioning(self) -> None: + from daft import col + + assert self._output_partitioning.type_ == "hash", "Only hash partitioning is currently supported" + hash_spec = self._output_partitioning.to_hash_pspec() + + # Drain `self._unpartitioned_data_buffer`, partitioning the data and placing it into `self._partitioned_data_buffer` + while self._unpartitioned_data_buffer: + micropartition = self._unpartitioned_data_buffer.pop(0) + partitions = micropartition.partition_by_hash( + ExpressionsProjection([col(c) for c in hash_spec.columns]), hash_spec.num_partitions + ) + for partition_request, partition in zip(self._partitioned_data_buffer.keys(), partitions): + self._partitioned_data_buffer[partition_request].append(partition) + + +class RayPerNodeActorFullyMaterializingShuffleService(ShuffleServiceInterface[ray.ObjectRef, ray.ObjectRef]): + """A ShuffleService implementation in Ray that utilizes Ray Actors on each node to perform a shuffle + + This is nice because it lets us `.ingest` data into each node's Actor before actually performing the shuffle, + reducing the complexity of the operation to O(num_nodes^2) instead of O(input_partitions * output_partitions) + """ + + # Default amount of bytes to request from Ray for each ShuffleActor + DEFAULT_SHUFFLE_ACTOR_MEMORY_REQUEST_BYTES: int = 1024 * 1024 * 1024 + + def __init__( + self, + output_partition_spec: PartitioningSpec, + per_actor_memory_request_bytes: int = DEFAULT_SHUFFLE_ACTOR_MEMORY_REQUEST_BYTES, + ): + self._input_stage_completed = False + self._output_partitioning_spec = output_partition_spec + + # Mapping of {node_id (str): Actor}, create one Actor per node + self._actors: dict[str, ShuffleServiceActor] = {} + self._placement_groups: dict[str, ray.util.placement_group.PlacementGroup] = {} + nodes = ray.nodes() + for node in nodes: + node_id = node["NodeID"] + pg = ray.util.placement_group( + [ + {"CPU": 0.01, "memory": per_actor_memory_request_bytes} + ], # Use minimal CPU and 100MB memory to avoid interfering with other tasks + strategy="STRICT_SPREAD", + ) + ray.get(pg.ready()) + self._placement_groups[node_id] = pg + for node_id, pg in self._placement_groups.items(): + actor = ShuffleServiceActor.options( # type: ignore + num_cpus=0.01, placement_group=pg, placement_group_bundle_index=0 + ).remote(self._output_partitioning_spec) + self._actors[node_id] = actor + + def teardown(self) -> None: + for actor in self._actors.values(): + ray.kill(actor) + for pg in self._placement_groups.values(): + ray.util.remove_placement_group(pg) + self._actors.clear() + self._placement_groups.clear() + + def ingest(self, data: Iterator[ray.ObjectRef]) -> list[ray.ObjectRef[None]]: + """Receive some data + + NOTE: This will throw an error if called after `.close_ingest` has been called. + """ + if self._input_stage_completed: + raise RuntimeError("Cannot ingest data after input stage is completed.") + + data_list = list(data) + best_effort_object_locations = ray.experimental.get_object_locations(data_list) + + # Get the corresponding actor for this node + ingestion_results = [] + for objref in data_list: + best_effort_node_id = best_effort_object_locations.get(objref) + if best_effort_node_id is None: + raise NotImplementedError( + "Need to implement fallback logic when location information unavailable in Ray for a partition. We can probably just do a separate round-robin assignment of shuffle Actors. 
This is expected to happen when the partition is small (roughly < 100KB), since it then wouldn't live in the plasma store and has no location info."
+                )
+            if best_effort_node_id not in self._actors:
+                raise RuntimeError(f"No ShuffleServiceActor found for node {best_effort_node_id}")
+            actor = self._actors[best_effort_node_id]
+            # Route only this object to the actor co-located with it
+            ingestion_results.append(actor.ingest.remote([objref]))  # type: ignore
+
+        return ingestion_results
+
+    def set_input_stage_completed(self) -> None:
+        """Inform the ShuffleService that all data from the previous stage has been ingested"""
+        self._input_stage_completed = True
+
+    def is_input_stage_completed(self) -> bool:
+        """Query whether or not the previous stage has completed ingestion"""
+        return self._input_stage_completed
+
+    def read(self, request: PartitionRequest, chunk_size_bytes: int) -> Iterator[ray.ObjectRef]:
+        """Retrieves ShuffleData from the shuffle service for the specified partition.
+
+        NOTE: This yields no data until `set_input_stage_completed` has been called.
+        """
+        # TODO: We currently enforce full materialization of the previous stage's outputs before allowing
+        # the subsequent stage to start reading data.
+        #
+        # This is also called a "pipeline breaker".
+        if not self.is_input_stage_completed():
+            return None
+
+        # Validate that the PartitionRequest matches the provided spec
+        if self._output_partitioning_spec.type_ != request.type_:
+            raise ValueError(
+                f"PartitionRequest type '{request.type_}' does not match the partitioning spec type '{self._output_partitioning_spec.type_}'"
+            )
+
+        # Validate the incoming PartitionRequest
+        hash_spec = self._output_partitioning_spec.to_hash_pspec()
+        hash_request = request.to_hash_request()
+        if hash_request.bucket >= hash_spec.num_partitions:
+            raise ValueError(
+                f"Requested bucket {hash_request.bucket} is out of range for {hash_spec.num_partitions} partitions"
+            )
+
+        # Iterate through all actors and yield ray.ObjectRef
+        for actor in self._actors.values():
+            while True:
+                # Request data from the actor
+                (micropartition_ref, actual_bytes_read) = actor.read.remote(request, chunk_size_bytes)
+
+                # TODO: Pretty sure this is really slow. Not sure what a good workaround is though, since
+                # there isn't a great mechanism to quickly query the actors about their remaining state.
+                #
+                # Because we know that `self.is_input_stage_completed() == True`,
+                # we know that there won't be any more data coming in. So we can YOLO and stop
+                # the iteration.
+ actual_bytes_read = ray.get(actual_bytes_read) + if actual_bytes_read == 0: + break + + yield micropartition_ref + + # After all actors are exhausted, we're done + return + + +class RayShuffleServiceFactory: + @contextlib.contextmanager + def fully_materializing_shuffle_service_context( + self, + num_partitions: int, + columns: list[str], + ) -> Iterator[RayPerNodeActorFullyMaterializingShuffleService]: + from daft.execution.physical_plan_shuffles import HashPartitioningSpec + + shuffle_service = RayPerNodeActorFullyMaterializingShuffleService( + HashPartitioningSpec(type_="hash", num_partitions=num_partitions, columns=columns) + ) + yield shuffle_service + shuffle_service.teardown() + + @contextlib.contextmanager + def push_based_shuffle_service_context( + self, + num_partitions: int, + partition_by: ExpressionsProjection, + ) -> Iterator[RayPushBasedShuffle]: + num_cpus = int(ray.cluster_resources()["CPU"]) + + # Number of mappers is ~2x number of mergers + num_merge_tasks = num_cpus // 3 + num_map_tasks = num_merge_tasks * 2 + (num_cpus % 3) + + yield RayPushBasedShuffle(num_map_tasks, num_merge_tasks, num_partitions, partition_by) + + +@ray.remote +def map_fn( + map_input: MicroPartition, num_mergers: int, partition_by: ExpressionsProjection, num_partitions: int +) -> list[list[MicroPartition]]: + """Returns `N` number of inputs, where `N` is the number of mergers""" + # Partition the input data based on the partitioning spec + partitioned_data = map_input.partition_by_hash(partition_by, num_partitions) + + outputs: list[list[MicroPartition]] = [[] for _ in range(num_mergers)] + + # Calculate the base number of partitions per merger and the number of mergers that get an extra partition + base_partitions_per_merger = num_partitions // num_mergers + extra_partitions = num_partitions % num_mergers + + num_mergers_with_extra_partitions = extra_partitions + num_partitions_assigned_to_mergers_with_extra_partitions = (base_partitions_per_merger + 1) * extra_partitions + + # Distribute the partitioned data across the mergers + for partition_idx, partition in enumerate(partitioned_data): + if partition_idx < num_partitions_assigned_to_mergers_with_extra_partitions: + # For the first 'extra_partitions' mergers, assign base + 1 partitions + merger_idx = partition_idx // (base_partitions_per_merger + 1) + else: + # For the remaining mergers, assign base partitions + merger_idx = ( + num_mergers_with_extra_partitions + + (partition_idx - num_partitions_assigned_to_mergers_with_extra_partitions) + // base_partitions_per_merger + ) + + outputs[merger_idx].append(partition) + + return outputs + + +@ray.remote +def merge_fn(*merger_inputs_across_mappers: list[MicroPartition]) -> list[MicroPartition]: + """Returns `P` number of inputs, where `P` is the number of reducers assigned to this merger""" + num_partitions_for_this_merger = len(merger_inputs_across_mappers[0]) + merged_partitions = [] + for partition_idx in range(num_partitions_for_this_merger): + partitions_to_merge = [data[partition_idx] for data in merger_inputs_across_mappers] + merged_partition = MicroPartition.concat(partitions_to_merge) + merged_partitions.append(merged_partition) + return merged_partitions + + +@ray.remote +def reduce_fn(*reduce_inputs_across_rounds: MicroPartition) -> MicroPartition: + """Returns 1 output, which is the reduced data across rounds""" + # Concatenate all input MicroPartitions across rounds + reduced_partition = MicroPartition.concat(list(reduce_inputs_across_rounds)) + + # Return the result as a list 
containing a single MicroPartition + return reduced_partition + + +class RayPushBasedShuffle: + def __init__(self, num_mappers: int, num_mergers: int, num_reducers: int, partition_by: ExpressionsProjection): + self._num_mappers = num_mappers + self._num_mergers = num_mergers + self._num_reducers = num_reducers + self._partition_by = partition_by + + def _num_reducers_for_merger(self, merger_idx: int) -> int: + base_num = self._num_reducers // self._num_mergers + extra = self._num_reducers % self._num_mergers + if merger_idx < extra: + return base_num + 1 + return base_num + + def _get_reducer_inputs_location(self, reducer_idx: int) -> tuple[int, int]: + """Returns the (merger_idx, offset) of where the inputs to a given reducer should live""" + for merger_idx in range(self._num_mergers): + num_reducers = self._num_reducers_for_merger(merger_idx) + if num_reducers > reducer_idx: + return merger_idx, reducer_idx + else: + reducer_idx -= num_reducers + raise ValueError(f"Cannot find merger for reducer_idx: {reducer_idx}") + + def _merger_options(self, merger_idx: int) -> dict[str, Any]: + num_nodes = len(ray.nodes()) + node_id = ray.nodes()[merger_idx % num_nodes]["NodeID"] + return { + "scheduling_strategy": ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy(node_id, soft=True) + } + + def _reduce_options(self, reducer_idx: int) -> dict[str, Any]: + assigned_merger_idx, _ = self._get_reducer_inputs_location(reducer_idx) + return self._merger_options(assigned_merger_idx) + + def run(self, materialized_inputs: list[ray.ObjectRef]) -> list[ray.ObjectRef]: + """Runs the Mappers and Mergers in a 2-stage pipeline until all mergers are materialized + + There are `R` reducers. + There are `N` mergers. Each merger is "responsible" for `R / N` reducers. + Each Mapper then should run partitioning on the data into `N` chunks. 
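For concreteness, with `R = 5` reducers and `N = 2` mergers: `_num_reducers_for_merger` makes merger 0 responsible for reducers {0, 1, 2} and merger 1 for reducers {3, 4}; each mapper therefore emits one list of hash partitions per merger (3 partitions to merger 0, 2 to merger 1); and `_get_reducer_inputs_location(4)` resolves to (merger 1, offset 1).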
+ """ + # [N_ROUNDS, N_MERGERS, N_REDUCERS_PER_MERGER] list of outputs + total_merge_results: list[list[list[ray.ObjectRef]]] = [] + map_results_buffer: list[ray.ObjectRef] = [] + + # Keep running the pipeline while there is still work to do + while materialized_inputs or map_results_buffer: + # Drain the map_results_buffer, running merge tasks + per_round_merge_results = [] + if map_results_buffer: + for merger_idx in range(self._num_mergers): + merger_input = [mapper_results[merger_idx] for mapper_results in map_results_buffer] + merge_results = merge_fn.options( + **self._merger_options(merger_idx), num_returns=self._num_reducers_for_merger(merger_idx) + ).remote(*merger_input) + per_round_merge_results.append(merge_results) + total_merge_results.append(per_round_merge_results) + + # Run map tasks: + map_results_buffer = [] + for i in range(self._num_mappers): + if len(materialized_inputs) == 0: + break + else: + map_input = materialized_inputs.pop(0) + map_results = map_fn.options(num_returns=self._num_mergers).remote( + map_input, self._num_mergers, self._partition_by, self._num_reducers + ) + map_results_buffer.append(map_results) + + # Wait for all tasks in this wave to complete + for results in per_round_merge_results: + ray.wait(results) + for results in map_results_buffer: + ray.wait(results) + + # INVARIANT: At this point, the map/merge step is done + # Start running all the reduce functions + # TODO: we could stagger this by num CPUs as well, but here we just YOLO run all + reduce_results = [] + for reducer_idx in range(self._num_reducers): + assigned_merger_idx, offset = self._get_reducer_inputs_location(reducer_idx) + reducer_inputs = [ + total_merge_results[round][assigned_merger_idx][offset] for round in range(len(total_merge_results)) + ] + res = reduce_fn.options(**self._reduce_options(reducer_idx)).remote(*reducer_inputs) + reduce_results.append(res) + return reduce_results diff --git a/daft/runners/runner.py b/daft/runners/runner.py index c1dd30f64e..74b64bb355 100644 --- a/daft/runners/runner.py +++ b/daft/runners/runner.py @@ -2,8 +2,9 @@ import contextlib from abc import abstractmethod -from typing import TYPE_CHECKING, Generic, Iterator +from typing import TYPE_CHECKING, Generic, Iterator, TypeVar +from daft.execution.physical_plan_shuffles import ShuffleServiceInterface from daft.runners.partitioning import ( MaterializedResult, PartitionCacheEntry, @@ -20,6 +21,9 @@ from daft.table import MicroPartition +ShuffleServiceImpl = TypeVar("ShuffleServiceImpl", bound=ShuffleServiceInterface) + + class Runner(Generic[PartitionT]): def __init__(self) -> None: self._part_set_cache = PartitionSetCache() diff --git a/src/daft-plan/src/display.rs b/src/daft-plan/src/display.rs index 76ba0e599a..22b035d782 100644 --- a/src/daft-plan/src/display.rs +++ b/src/daft-plan/src/display.rs @@ -56,6 +56,7 @@ impl TreeDisplay for crate::physical_plan::PhysicalPlan { Self::DeltaLakeWrite(write) => write.display_as(level), #[cfg(feature = "python")] Self::LanceWrite(write) => write.display_as(level), + Self::ExchangeOp(exchange_op) => exchange_op.display_as(level), } } diff --git a/src/daft-plan/src/physical_ops/exchange_op.rs b/src/daft-plan/src/physical_ops/exchange_op.rs new file mode 100644 index 0000000000..1ea5372484 --- /dev/null +++ b/src/daft-plan/src/physical_ops/exchange_op.rs @@ -0,0 +1,39 @@ +use std::sync::Arc; + +use serde::{Deserialize, Serialize}; + +use crate::{impl_default_tree_display, ClusteringSpec, PhysicalPlanRef}; + +#[derive(Clone, Debug, PartialEq, Serialize, 
Deserialize)] +pub struct ExchangeOp { + pub input: PhysicalPlanRef, + pub strategy: ExchangeOpStrategy, +} + +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub enum ExchangeOpStrategy { + /// Fully materialize the data after the Map, and then pull results from the Reduce. + FullyMaterializingPull { target_spec: Arc }, + /// Stand up Reducers and then send data from the mappers into the reducers eagerly + FullyMaterializingPush { target_spec: Arc }, +} + +impl ExchangeOp { + pub fn multiline_display(&self) -> Vec { + let mut res = vec![]; + res.push("ExchangeOp:".to_string()); + match &self.strategy { + ExchangeOpStrategy::FullyMaterializingPull { target_spec } => { + res.push(" Strategy: FullyMaterializingPull".to_string()); + res.push(format!(" Target Spec: {:?}", target_spec)); + } + ExchangeOpStrategy::FullyMaterializingPush { target_spec } => { + res.push(" Strategy: FullyMaterializingPush".to_string()); + res.push(format!(" Target Spec: {:?}", target_spec)); + } + } + res + } +} + +impl_default_tree_display!(ExchangeOp); diff --git a/src/daft-plan/src/physical_ops/mod.rs b/src/daft-plan/src/physical_ops/mod.rs index 8a9a79a658..dfa6539969 100644 --- a/src/daft-plan/src/physical_ops/mod.rs +++ b/src/daft-plan/src/physical_ops/mod.rs @@ -8,6 +8,7 @@ mod csv; mod deltalake_write; mod empty_scan; +mod exchange_op; mod explode; mod fanout; mod filter; @@ -41,6 +42,7 @@ pub use csv::TabularWriteCsv; #[cfg(feature = "python")] pub use deltalake_write::DeltaLakeWrite; pub use empty_scan::EmptyScan; +pub use exchange_op::{ExchangeOp, ExchangeOpStrategy}; pub use explode::Explode; pub use fanout::{FanoutByHash, FanoutByRange, FanoutRandom}; pub use filter::Filter; diff --git a/src/daft-plan/src/physical_optimization/rules/reorder_partition_keys.rs b/src/daft-plan/src/physical_optimization/rules/reorder_partition_keys.rs index c7f706abb0..01fd1e5f19 100644 --- a/src/daft-plan/src/physical_optimization/rules/reorder_partition_keys.rs +++ b/src/daft-plan/src/physical_optimization/rules/reorder_partition_keys.rs @@ -4,7 +4,10 @@ use daft_dsl::{is_partition_compatible, ExprRef}; use crate::{ partitioning::HashClusteringConfig, - physical_ops::{Aggregate, Explode, FanoutByHash, HashJoin, Project, Unpivot}, + physical_ops::{ + Aggregate, ExchangeOp, ExchangeOpStrategy, Explode, FanoutByHash, HashJoin, Project, + Unpivot, + }, physical_optimization::{plan_context::PlanContext, rules::PhysicalOptimizerRule}, ClusteringSpec, PhysicalPlan, PhysicalPlanRef, }; @@ -130,6 +133,20 @@ impl PhysicalOptimizerRule for ReorderPartitionKeys { }); Ok(Transformed::yes(c.with_plan(new_plan.into()).propagate())) } + PhysicalPlan::ExchangeOp(ExchangeOp{input, strategy: ExchangeOpStrategy::FullyMaterializingPull { .. }}) => { + let new_plan = PhysicalPlan::ExchangeOp(ExchangeOp { + input: input.clone(), + strategy: ExchangeOpStrategy::FullyMaterializingPull { target_spec: new_spec.into() } + }); + Ok(Transformed::yes(c.with_plan(new_plan.into()).propagate())) + } + PhysicalPlan::ExchangeOp(ExchangeOp{input, strategy: ExchangeOpStrategy::FullyMaterializingPush { .. }}) => { + let new_plan = PhysicalPlan::ExchangeOp(ExchangeOp { + input: input.clone(), + strategy: ExchangeOpStrategy::FullyMaterializingPush { target_spec: new_spec.into() } + }); + Ok(Transformed::yes(c.with_plan(new_plan.into()).propagate())) + } // these depend solely on their input PhysicalPlan::Filter(..) 
| diff --git a/src/daft-plan/src/physical_plan.rs b/src/daft-plan/src/physical_plan.rs index 615d656b92..813cccf61e 100644 --- a/src/daft-plan/src/physical_plan.rs +++ b/src/daft-plan/src/physical_plan.rs @@ -45,6 +45,7 @@ pub enum PhysicalPlan { TabularWriteParquet(TabularWriteParquet), TabularWriteJson(TabularWriteJson), TabularWriteCsv(TabularWriteCsv), + ExchangeOp(ExchangeOp), #[cfg(feature = "python")] IcebergWrite(IcebergWrite), #[cfg(feature = "python")] @@ -240,6 +241,14 @@ impl PhysicalPlan { Self::IcebergWrite(_) | Self::DeltaLakeWrite(_) | Self::LanceWrite(_) => { ClusteringSpec::Unknown(UnknownClusteringConfig::new(1)).into() } + Self::ExchangeOp(ExchangeOp { + strategy: ExchangeOpStrategy::FullyMaterializingPull { target_spec }, + .. + }) + | Self::ExchangeOp(ExchangeOp { + strategy: ExchangeOpStrategy::FullyMaterializingPush { target_spec }, + .. + }) => target_spec.clone(), } } @@ -408,6 +417,7 @@ impl PhysicalPlan { Self::IcebergWrite(_) | Self::DeltaLakeWrite(_) | Self::LanceWrite(_) => { ApproxStats::empty() } + Self::ExchangeOp(ExchangeOp { input, .. }) => input.approximate_stats(), } } @@ -454,6 +464,7 @@ impl PhysicalPlan { Self::MonotonicallyIncreasingId(MonotonicallyIncreasingId { input, .. }) => { vec![input] } + Self::ExchangeOp(ExchangeOp { input, .. }) => vec![input], } } @@ -493,6 +504,7 @@ impl PhysicalPlan { Self::DeltaLakeWrite(DeltaLakeWrite {schema, delta_lake_info, .. }) => Self::DeltaLakeWrite(DeltaLakeWrite::new(schema.clone(), delta_lake_info.clone(), input.clone())), #[cfg(feature = "python")] Self::LanceWrite(LanceWrite { schema, lance_info, .. }) => Self::LanceWrite(LanceWrite::new(schema.clone(), lance_info.clone(), input.clone())), + Self::ExchangeOp(ExchangeOp{strategy, ..}) => Self::ExchangeOp(ExchangeOp{ input: input.clone(), strategy: strategy.clone()}), Self::Concat(_) | Self::HashJoin(_) | Self::SortMergeJoin(_) | Self::BroadcastJoin(_) => panic!("{} requires more than 1 input, but received: {}", self, children.len()), }, [input1, input2] => match self { @@ -552,6 +564,14 @@ impl PhysicalPlan { Self::DeltaLakeWrite(..) => "DeltaLakeWrite", #[cfg(feature = "python")] Self::LanceWrite(..) => "LanceWrite", + Self::ExchangeOp(ExchangeOp { + strategy: ExchangeOpStrategy::FullyMaterializingPull { .. }, + .. + }) => "ExchangeOp[FullyMaterializing]", + Self::ExchangeOp(ExchangeOp { + strategy: ExchangeOpStrategy::FullyMaterializingPush { .. }, + .. 
+ }) => "ExchangeOp[FullyMaterializingPush]", }; name.to_string() } @@ -596,6 +616,7 @@ impl PhysicalPlan { Self::DeltaLakeWrite(delta_lake_info) => delta_lake_info.multiline_display(), #[cfg(feature = "python")] Self::LanceWrite(lance_info) => lance_info.multiline_display(), + Self::ExchangeOp(exchange_op) => exchange_op.multiline_display(), } } diff --git a/src/daft-plan/src/physical_planner/translate.rs b/src/daft-plan/src/physical_planner/translate.rs index 639c571871..52197c5c2b 100644 --- a/src/daft-plan/src/physical_planner/translate.rs +++ b/src/daft-plan/src/physical_planner/translate.rs @@ -8,7 +8,7 @@ use common_daft_config::DaftExecutionConfig; use common_error::DaftResult; use common_file_formats::FileFormat; use daft_core::prelude::*; -use daft_dsl::{col, is_partition_compatible, ApproxPercentileParams, ExprRef, SketchType}; +use daft_dsl::{col, is_partition_compatible, ApproxPercentileParams, Expr, ExprRef, SketchType}; use daft_scan::PhysicalScanInfo; use crate::{ @@ -30,6 +30,42 @@ use crate::{ source_info::{PlaceHolderInfo, SourceInfo}, }; +/// Builds an exchange op (PhysicalPlan node(s) that shuffles data) +fn build_exchange_op( + input: PhysicalPlanRef, + num_partitions: usize, + partition_by: Vec>, +) -> PhysicalPlan { + let exchange_op = std::env::var("DAFT_EXCHANGE_OP"); + match exchange_op.as_deref() { + Err(_) => { + let split_op = + PhysicalPlan::FanoutByHash(FanoutByHash::new(input, num_partitions, partition_by)) + .arced(); + PhysicalPlan::ReduceMerge(ReduceMerge::new(split_op)) + } + Ok("pull") => PhysicalPlan::ExchangeOp(ExchangeOp { + input, + strategy: ExchangeOpStrategy::FullyMaterializingPull { + target_spec: Arc::new(ClusteringSpec::Hash(HashClusteringConfig::new( + num_partitions, + partition_by, + ))), + }, + }), + Ok("push") => PhysicalPlan::ExchangeOp(ExchangeOp { + input, + strategy: ExchangeOpStrategy::FullyMaterializingPush { + target_spec: Arc::new(ClusteringSpec::Hash(HashClusteringConfig::new( + num_partitions, + partition_by, + ))), + }, + }), + Ok(exo) => panic!("Unsupported DAFT_EXCHANGE_OP={exo}"), + } +} + pub(super) fn translate_single_logical_node( logical_plan: &LogicalPlan, physical_children: &mut Vec, @@ -206,6 +242,7 @@ pub(super) fn translate_single_logical_node( } } ClusteringSpec::Random(_) => { + // TODO: Support Random clustering spec for ExchangeOps let split_op = PhysicalPlan::FanoutRandom(FanoutRandom::new( input_physical, num_partitions, @@ -213,12 +250,7 @@ pub(super) fn translate_single_logical_node( PhysicalPlan::ReduceMerge(ReduceMerge::new(split_op.into())) } ClusteringSpec::Hash(HashClusteringConfig { by, .. 
}) => { - let split_op = PhysicalPlan::FanoutByHash(FanoutByHash::new( - input_physical, - num_partitions, - by.clone(), - )); - PhysicalPlan::ReduceMerge(ReduceMerge::new(split_op.into())) + build_exchange_op(input_physical, num_partitions, by.clone()) } ClusteringSpec::Range(_) => { unreachable!("Repartitioning by range is not supported") @@ -238,14 +270,10 @@ pub(super) fn translate_single_logical_node( PhysicalPlan::Aggregate(Aggregate::new(input_physical, vec![], col_exprs.clone())); let num_partitions = agg_op.clustering_spec().num_partitions(); if num_partitions > 1 { - let split_op = PhysicalPlan::FanoutByHash(FanoutByHash::new( - agg_op.into(), - num_partitions, - col_exprs.clone(), - )); - let reduce_op = PhysicalPlan::ReduceMerge(ReduceMerge::new(split_op.into())); + let exchange_op = + build_exchange_op(agg_op.into(), num_partitions, col_exprs.clone()); Ok( - PhysicalPlan::Aggregate(Aggregate::new(reduce_op.into(), vec![], col_exprs)) + PhysicalPlan::Aggregate(Aggregate::new(exchange_op.into(), vec![], col_exprs)) .arced(), ) } else { @@ -309,16 +337,15 @@ pub(super) fn translate_single_logical_node( )) .arced() } else { - let split_op = PhysicalPlan::FanoutByHash(FanoutByHash::new( + build_exchange_op( first_stage_agg, min( num_input_partitions, cfg.shuffle_aggregation_default_partitions, ), groupby.clone(), - )) - .arced(); - PhysicalPlan::ReduceMerge(ReduceMerge::new(split_op)).arced() + ) + .arced() }; let second_stage_agg = PhysicalPlan::Aggregate(Aggregate::new( @@ -378,7 +405,7 @@ pub(super) fn translate_single_logical_node( )) .arced() } else { - let split_op = PhysicalPlan::FanoutByHash(FanoutByHash::new( + build_exchange_op( first_stage_agg, min( num_input_partitions, @@ -387,9 +414,8 @@ pub(super) fn translate_single_logical_node( // NOTE: For the shuffle of a pivot operation, we don't include the pivot column for the hashing as we need // to ensure that all rows with the same group_by column values are hashed to the same partition. 
group_by.clone(), - )) - .arced(); - PhysicalPlan::ReduceMerge(ReduceMerge::new(split_op)).arced() + ) + .arced() }; let second_stage_agg = PhysicalPlan::Aggregate(Aggregate::new( @@ -641,24 +667,16 @@ pub(super) fn translate_single_logical_node( if num_left_partitions != num_partitions || (num_partitions > 1 && !is_left_hash_partitioned) { - let split_op = PhysicalPlan::FanoutByHash(FanoutByHash::new( - left_physical, - num_partitions, - left_on.clone(), - )); left_physical = - PhysicalPlan::ReduceMerge(ReduceMerge::new(split_op.into())).arced(); + build_exchange_op(left_physical, num_partitions, left_on.clone()) + .arced(); } if num_right_partitions != num_partitions || (num_partitions > 1 && !is_right_hash_partitioned) { - let split_op = PhysicalPlan::FanoutByHash(FanoutByHash::new( - right_physical, - num_partitions, - right_on.clone(), - )); right_physical = - PhysicalPlan::ReduceMerge(ReduceMerge::new(split_op.into())).arced(); + build_exchange_op(right_physical, num_partitions, right_on.clone()) + .arced(); } Ok(PhysicalPlan::HashJoin(HashJoin::new( left_physical, diff --git a/src/daft-scheduler/src/scheduler.rs b/src/daft-scheduler/src/scheduler.rs index 709dd8ff4d..04fa87c2d3 100644 --- a/src/daft-scheduler/src/scheduler.rs +++ b/src/daft-scheduler/src/scheduler.rs @@ -784,5 +784,45 @@ fn physical_plan_to_partition_tasks( physical_plan_to_partition_tasks(input, py, psets)?, lance_info, ), + PhysicalPlan::ExchangeOp(ExchangeOp { + input, + strategy: ExchangeOpStrategy::FullyMaterializingPull { target_spec }, + }) => { + let upstream_iter = physical_plan_to_partition_tasks(input, py, psets)?; + let partition_by_pyexprs: Vec = target_spec + .partition_by() + .iter() + .map(|expr| PyExpr::from(expr.clone())) + .collect(); + let py_iter = py + .import_bound(pyo3::intern!(py, "daft.execution.physical_plan"))? + .getattr(pyo3::intern!(py, "fully_materializing_exchange_op"))? + .call1(( + upstream_iter, + partition_by_pyexprs, + target_spec.num_partitions(), + ))?; + Ok(py_iter.into()) + } + PhysicalPlan::ExchangeOp(ExchangeOp { + input, + strategy: ExchangeOpStrategy::FullyMaterializingPush { target_spec }, + }) => { + let upstream_iter = physical_plan_to_partition_tasks(input, py, psets)?; + let partition_by_pyexprs: Vec = target_spec + .partition_by() + .iter() + .map(|expr| PyExpr::from(expr.clone())) + .collect(); + let py_iter = py + .import_bound(pyo3::intern!(py, "daft.execution.physical_plan"))? + .getattr(pyo3::intern!(py, "fully_materializing_push_exchange_op"))? + .call1(( + upstream_iter, + partition_by_pyexprs, + target_spec.num_partitions(), + ))?; + Ok(py_iter.into()) + } } }
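Selecting an exchange implementation: `build_exchange_op` in `translate.rs` above reads the `DAFT_EXCHANGE_OP` environment variable at physical-planning time. When the variable is unset, planning keeps the existing `FanoutByHash` + `ReduceMerge` path; `pull` and `push` select the two `ExchangeOp` strategies, and any other value panics. A minimal sketch of opting in from Python (a hedged illustration, assuming the Ray runner, since both exchange paths are implemented in `ray_runner.py`):

    import os

    # translate.rs reads this via std::env::var during physical planning,
    # so it must be set in the driver process before the query is planned.
    os.environ["DAFT_EXCHANGE_OP"] = "push"  # or "pull"; leave unset for the default fanout/reduce path

    import daft

    df = daft.from_pydict({"key": [1, 2, 3, 4], "value": ["a", "b", "c", "d"]})
    df = df.repartition(8, "key")  # hash repartition, planned as ExchangeOp[FullyMaterializingPush]
    df.collect()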
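Sizing and driving the push-based shuffle directly: `push_based_shuffle_service_context` derives its task pool from the Ray cluster, with `num_merge_tasks = num_cpus // 3` and `num_map_tasks = 2 * num_merge_tasks + num_cpus % 3` (for example, 8 CPUs gives 2 mergers and 6 mappers). A sketch of using it the way `fully_materializing_push_exchange_op` does, under the assumptions that a Ray cluster is up and that `MicroPartition.from_pydict` (used here only to fabricate small inputs) is available:

    import ray

    from daft import col
    from daft.context import get_context
    from daft.expressions import ExpressionsProjection
    from daft.table import MicroPartition

    ray.init(ignore_reinit_error=True)

    # Stand-ins for the already-materialized child partitions ("key" is a placeholder column).
    materialized_refs = [
        ray.put(MicroPartition.from_pydict({"key": [i, i + 1], "value": [float(i), float(i + 1)]}))
        for i in range(4)
    ]

    partition_by = ExpressionsProjection([col("key")])
    with get_context().shuffle_service_factory().push_based_shuffle_service_context(
        8, partition_by=partition_by
    ) as shuffle_service:
        # Returns one ray.ObjectRef per output partition; each becomes the input of a
        # downstream PartitionTaskBuilder in fully_materializing_push_exchange_op.
        reduced_refs = shuffle_service.run(materialized_refs)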