From 716c0e70dab876044bd61aef9e0b1a3de564121b Mon Sep 17 00:00:00 2001 From: clarkzinzow Date: Thu, 18 Jan 2024 14:58:42 -0800 Subject: [PATCH] Add MergeJoinTaskTracker abstraction to clean things up. --- daft/execution/physical_plan.py | 217 ++++++++++++++++++++++---------- 1 file changed, 151 insertions(+), 66 deletions(-) diff --git a/daft/execution/physical_plan.py b/daft/execution/physical_plan.py index 35dd7ac27b..fc32e579b9 100644 --- a/daft/execution/physical_plan.py +++ b/daft/execution/physical_plan.py @@ -19,7 +19,7 @@ import math import pathlib from collections import deque -from typing import Generator, Iterator, List, TypeVar, Union, cast +from typing import Generator, Generic, Iterable, Iterator, TypeVar, Union from daft.context import get_context from daft.daft import ( @@ -403,13 +403,133 @@ def broadcast_join( return +class MergeJoinTaskTracker(Generic[PartitionT]): + """ + Tracks merge-join tasks for each larger-side partition. + + Merge-join tasks are added to the tracker, and the tracker handles empty tasks, finalizing PartitionTaskBuilders, + determining whether tasks are ready to be executed, checking whether tasks are done, and deciding whether a coalesce + is needed. + """ + + def __init__(self, stage_id: int): + # Merge-join tasks that have not yet been finalized or yielded to the runner. We don't finalize a merge-join + # task until we have at least 2 non-empty merge-join tasks, at which point this task will be popped from + # _task_staging, finalized, and put into _finalized_tasks. + self._task_staging: dict[str, PartitionTaskBuilder[PartitionT]] = {} + # Merge-join tasks that have been finalized, but not yet yielded to the runner. + self._finalized_tasks: collections.defaultdict[ + str, deque[SingleOutputPartitionTask[PartitionT]] + ] = collections.defaultdict(deque) + # Merge-join tasks that have been yielded to the runner, and still need to be coalesced. + self._uncoalesced_tasks: collections.defaultdict[ + str, deque[SingleOutputPartitionTask[PartitionT]] + ] = collections.defaultdict(deque) + # Larger-side partitions that have been finalized, i.e. we're guaranteed that no more smaller-side partitions + # will be added to the tracker for this partition. + self._finalized: dict[str, bool] = {} + self._stage_id = stage_id + + def add_task(self, part_id: str, task: PartitionTaskBuilder[PartitionT]) -> None: + """ + Add a merge-join task to the tracker for the provided larger-side partition. + + This task needs to be unfinalized, i.e. a PartitionTaskBuilder. + """ + # If no merge-join tasks have been added to the tracker yet for this partition, or we have an empty task in + # staging, add the unfinalized merge-join task to staging. + if not self._is_contained(part_id) or ( + part_id in self._task_staging and self._task_staging[part_id].is_empty() + ): + self._task_staging[part_id] = task + # Otherwise, we have at least 2 (probably) non-empty merge-join tasks, so we finalize the new task and add it + # to _finalized_tasks. If the new task is empty, then we drop it (we already have at least one task for this + # partition, so no use in keeping an additional empty task around). + elif not task.is_empty(): + # If we have a task in staging, we know from the first if statement that it's non-empty, so we finalize it + # and add it to _finalized_tasks. + if part_id in self._task_staging: + self._finalized_tasks[part_id].append( + self._task_staging.pop(part_id).finalize_partition_task_single_output(self._stage_id) + ) + self._finalized_tasks[part_id].append(task.finalize_partition_task_single_output(self._stage_id)) + + def finalize(self, part_id: str) -> None: + """ + Indicates to the tracker that we are done adding merge-join tasks for this partition. + """ + # All finalized tasks should have been yielded before the tracker.finalize() call. + finalized_tasks = self._finalized_tasks.pop(part_id, deque()) + assert len(finalized_tasks) == 0 + + self._finalized[part_id] = True + + def yield_ready( + self, part_id: str + ) -> Iterator[SingleOutputPartitionTask[PartitionT] | PartitionTaskBuilder[PartitionT]]: + """ + Returns an iterator of all tasks for this partition that are ready for execution. Each merge-join task will be + yielded once, even across multiple calls. + """ + assert self._is_contained(part_id) + if part_id in self._finalized_tasks: + # Yield the finalized tasks and add them to the uncoalesced queue. + while self._finalized_tasks[part_id]: + task = self._finalized_tasks[part_id].popleft() + yield task + self._uncoalesced_tasks[part_id].append(task) + elif self._finalized.get(part_id, False) and part_id in self._task_staging: + # If the tracker has been finalized for this partition, we can yield unfinalized tasks directly from + # staging since no future tasks will be added. + yield self._task_staging.pop(part_id) + + def pop_uncoalesced(self, part_id: str) -> deque[SingleOutputPartitionTask[PartitionT]] | None: + """ + Returns all tasks for this partition that need to be coalesced. If this partition only involved a single + merge-join task (i.e. we don't need to coalesce), this this function will return None. + + NOTE: tracker.finalize(part_id) must be called before this function. + """ + assert self._finalized[part_id] + return self._uncoalesced_tasks.pop(part_id, None) + + def all_tasks_done_for_partition(self, part_id: str) -> bool: + """ + Return whether all merge-join tasks for this partition are done. + """ + assert self._is_contained(part_id) + if part_id in self._task_staging: + # Unfinalized tasks are trivially "done". + return True + return all( + task.done() + for task in itertools.chain( + self._finalized_tasks.get(part_id, deque()), self._uncoalesced_tasks.get(part_id, deque()) + ) + ) + + def all_tasks_done(self) -> bool: + """ + Return whether all merge-join tasks for all partitions are done. + """ + return all( + self.all_tasks_done_for_partition(part_id) + for part_id in itertools.chain( + self._uncoalesced_tasks.keys(), self._finalized_tasks.keys(), self._task_staging.keys() + ) + ) + + def _is_contained(self, part_id: str) -> bool: + """ + Return whether the provided partition is being tracked by this tracker. + """ + return part_id in self._task_staging or part_id in self._finalized_tasks or part_id in self._uncoalesced_tasks + + def _emit_merge_joins_on_window( next_part: SingleOutputPartitionTask[PartitionT], other_window: deque[SingleOutputPartitionTask[PartitionT]], - merge_join_partition_tasks: collections.defaultdict[ - str, list[PartitionTaskBuilder[PartitionT] | PartitionTask[PartitionT]] - ], - stage_id: int, + merge_join_task_tracker: MergeJoinTaskTracker[PartitionT], flipped: bool, next_is_larger: bool, left_on: ExpressionsProjection, @@ -431,7 +551,7 @@ def _emit_merge_joins_on_window( if flipped: inputs = list(reversed(inputs)) partial_metadatas = list(reversed(partial_metadatas)) - join_task: PartitionTaskBuilder[PartitionT] | PartitionTask[PartitionT] = PartitionTaskBuilder[PartitionT]( + join_task = PartitionTaskBuilder[PartitionT]( inputs=inputs, partial_metadatas=partial_metadatas, resource_request=ResourceRequest(memory_bytes=memory_bytes), @@ -446,31 +566,8 @@ def _emit_merge_joins_on_window( # Add to new merge-join step to tracked steps for this larger-side partition, and possibly start finalizing + # emitting non-empty join steps if there are now more than one. part_id = next_part.id() if next_is_larger else other_next_part.id() - tasks = merge_join_partition_tasks[part_id] - if len(tasks) == 1: - # If the only merge-join task we have for this partition is empty, remove it and replace it (below) - # with our new (probably) non-empty merge-join task. - if tasks[0].is_empty(): - tasks.pop() - elif not join_task.is_empty(): - # There are currently two (probably non-empty) merge-join tasks so we'll need to issue a coalesce - # partition task later, so finalize and yield the first merge-join task now. - # The second will be finalized and yielded below. - unfinalized_task = tasks[0] - assert isinstance(unfinalized_task, PartitionTaskBuilder) - tasks[0] = unfinalized_task.finalize_partition_task_single_output(stage_id) - yield tasks[0] - if not join_task.is_empty() and len(tasks) >= 1: - # There are at least two (probably non-empty) merge-join tasks so we'll need to issue a coalesce partition - # task later, so finalize and yield the new merge-join task now. - assert isinstance(join_task, PartitionTaskBuilder) - join_task = join_task.finalize_partition_task_single_output(stage_id) - yield join_task - # Add merge-join task to appropriate group. - if not join_task.is_empty() or not tasks: - # If merge-join task is empty and there are no merge-join tasks yet for this partition, we still add the - # empty task in case we need to propagate an empty partition later. - tasks.append(join_task) + merge_join_task_tracker.add_task(part_id, join_task) + yield from merge_join_task_tracker.yield_ready(part_id) def _memory_bytes_for_merge( @@ -511,6 +608,8 @@ def merge_join_sorted( larger_plan = left_plan if left_is_larger else right_plan smaller_plan = right_plan if left_is_larger else left_plan + stage_id = next(stage_id_counter) + # In-progress tasks for larger side of join. larger_requests: deque[SingleOutputPartitionTask[PartitionT]] = deque() # In-progress tasks for smaller side of join. @@ -522,17 +621,14 @@ def merge_join_sorted( # larger-side materialized partition has a higher upper bound, which suggests that this smaller-side partition won't # be able to intersect with any future larger-side partitions. smaller_window: deque[SingleOutputPartitionTask[PartitionT]] = deque() - # A map from IDs of input partitions from the larger side of the join to a list of merge-join partition tasks that - # were emitted on said input partition. + # Tracks merge-join partition tasks emitted for each partition on the larger side of the join. # Once all merge-join tasks are done, the corresponding output partitions will be coalesced together. - # If the merge-join task list only contains a single element, it will be an unfinalized PartitionTaskBuilder, - # the coalescing step will be skipped, and this merge-join task will be yielded without finalizing in order to - # allow fusion with downstream tasks; otherwise, the list will contain finalized PartitionTasks. - merge_join_partition_tasks: collections.defaultdict[ - str, list[PartitionTaskBuilder[PartitionT] | PartitionTask[PartitionT]] - ] = collections.defaultdict(list) + # If only a single merge-join task is emitted for a larger-side partition, it will be an unfinalized + # PartitionTaskBuilder, the coalescing step will be skipped, and this merge-join task will be yielded without + # finalizing in order to allow fusion with downstream tasks; otherwise, the tracker will contain finalized + # PartitionTasks. + merge_join_task_tracker: MergeJoinTaskTracker[PartitionT] = MergeJoinTaskTracker(stage_id) - stage_id = next(stage_id_counter) yield_smaller = True smaller_done = False larger_done = False @@ -547,8 +643,7 @@ def merge_join_sorted( yield from _emit_merge_joins_on_window( next_part, larger_window, - merge_join_partition_tasks, - stage_id, + merge_join_task_tracker, left_is_larger, False, left_on, @@ -563,8 +658,7 @@ def merge_join_sorted( yield from _emit_merge_joins_on_window( next_part, smaller_window, - merge_join_partition_tasks, - stage_id, + merge_join_task_tracker, not left_is_larger, True, left_on, @@ -602,32 +696,27 @@ def merge_join_sorted( # Only finalize merge-join tasks for larger-side partition if all outputs are done OR there's only a # single finalized output (in which case we yield and unfinalized merge-join task to allow downstream # fusion with it). - all( - isinstance(task, PartitionTaskBuilder) or task.done() - for task in merge_join_partition_tasks[larger_window[0].id()] - ) + merge_join_task_tracker.all_tasks_done_for_partition(larger_window[0].id()) ) ): done_larger_part = larger_window.popleft() - tasks = merge_join_partition_tasks.pop(done_larger_part.id()) - assert len(tasks) > 0 - if len(tasks) == 1: - # Only one output partition, so no coalesce needed; yield the merge-join partition task without - # finalizing to allow for downstream fusing. - # This output partition may be predetermined to be empty. - yield tasks[0] + part_id = done_larger_part.id() + merge_join_task_tracker.finalize(part_id) + yield from merge_join_task_tracker.yield_ready(part_id) + tasks = merge_join_task_tracker.pop_uncoalesced(part_id) + if tasks is None: + # Only one output partition, so no coalesce needed. continue - tasks_ = cast(List[SingleOutputPartitionTask], tasks) # At least two (probably non-empty) merge-join tasks for this group, so need to coalesce. # NOTE: We guarantee in _emit_merge_joins_on_window that any group containing 2 or more partition tasks # will only contain non-guaranteed-empty partitions; i.e., we'll need to execute a task to determine if # they actually are empty, so we just issue the coalesce task. # TODO(Clark): Elide coalescing by emitting a single merge-join task per larger-side partition, including as # input all intersecting partitions from the smaller side of the join. - size_bytes = _memory_bytes_for_coalesce(tasks_) + size_bytes = _memory_bytes_for_coalesce(tasks) coalesce_task = PartitionTaskBuilder[PartitionT]( - inputs=[task.partition() for task in tasks_], - partial_metadatas=[task.partition_metadata() for task in tasks_], + inputs=[task.partition() for task in tasks], + partial_metadatas=[task.partition_metadata() for task in tasks], resource_request=ResourceRequest(memory_bytes=size_bytes), ).add_instruction( instruction=execution_step.ReduceMerge(), @@ -680,14 +769,10 @@ def merge_join_sorted( larger_done = True # We might still be waiting for some merge-join tasks to complete whose output we still need # to coalesce. - elif any( - isinstance(task, SingleOutputPartitionTask) and not task.done() - for tasks in merge_join_partition_tasks.values() - for task in tasks - ): + elif not merge_join_task_tracker.all_tasks_done(): logger.debug( "merge join blocked on completion of merge join tasks (pre-coalesce).\nMerge-join tasks: %s", - list(merge_join_partition_tasks.values()), + list(merge_join_task_tracker._finalized_tasks.values()), ) yield None # Otherwise, all join inputs are done and all merge-join tasks are done, so we are entirely done emitting @@ -709,7 +794,7 @@ def _is_strictly_bounded_above_by( return lower_boundaries.is_strictly_bounded_above_by(upper_boundaries) -def _memory_bytes_for_coalesce(input_parts: list[SingleOutputPartitionTask[PartitionT]]) -> int | None: +def _memory_bytes_for_coalesce(input_parts: Iterable[SingleOutputPartitionTask[PartitionT]]) -> int | None: # Calculate memory request for task. size_bytes_per_task = [task.partition_metadata().size_bytes for task in input_parts] non_null_size_bytes_per_task = [size for size in size_bytes_per_task if size is not None]