Skip to content

Commit

Permalink
Add MergeJoinTaskTracker abstraction to clean things up.
Browse files Browse the repository at this point in the history
  • Loading branch information
clarkzinzow committed Jan 18, 2024
1 parent de95c6f commit 716c0e7
Showing 1 changed file with 151 additions and 66 deletions.
217 changes: 151 additions & 66 deletions daft/execution/physical_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import math
import pathlib
from collections import deque
from typing import Generator, Iterator, List, TypeVar, Union, cast
from typing import Generator, Generic, Iterable, Iterator, TypeVar, Union

from daft.context import get_context
from daft.daft import (
Expand Down Expand Up @@ -403,13 +403,133 @@ def broadcast_join(
return


class MergeJoinTaskTracker(Generic[PartitionT]):
"""
Tracks merge-join tasks for each larger-side partition.
Merge-join tasks are added to the tracker, and the tracker handles empty tasks, finalizing PartitionTaskBuilders,
determining whether tasks are ready to be executed, checking whether tasks are done, and deciding whether a coalesce
is needed.
"""

def __init__(self, stage_id: int):
# Merge-join tasks that have not yet been finalized or yielded to the runner. We don't finalize a merge-join
# task until we have at least 2 non-empty merge-join tasks, at which point this task will be popped from
# _task_staging, finalized, and put into _finalized_tasks.
self._task_staging: dict[str, PartitionTaskBuilder[PartitionT]] = {}
# Merge-join tasks that have been finalized, but not yet yielded to the runner.
self._finalized_tasks: collections.defaultdict[
str, deque[SingleOutputPartitionTask[PartitionT]]
] = collections.defaultdict(deque)
# Merge-join tasks that have been yielded to the runner, and still need to be coalesced.
self._uncoalesced_tasks: collections.defaultdict[
str, deque[SingleOutputPartitionTask[PartitionT]]
] = collections.defaultdict(deque)
# Larger-side partitions that have been finalized, i.e. we're guaranteed that no more smaller-side partitions
# will be added to the tracker for this partition.
self._finalized: dict[str, bool] = {}
self._stage_id = stage_id

def add_task(self, part_id: str, task: PartitionTaskBuilder[PartitionT]) -> None:
"""
Add a merge-join task to the tracker for the provided larger-side partition.
This task needs to be unfinalized, i.e. a PartitionTaskBuilder.
"""
# If no merge-join tasks have been added to the tracker yet for this partition, or we have an empty task in
# staging, add the unfinalized merge-join task to staging.
if not self._is_contained(part_id) or (
part_id in self._task_staging and self._task_staging[part_id].is_empty()
):
self._task_staging[part_id] = task
# Otherwise, we have at least 2 (probably) non-empty merge-join tasks, so we finalize the new task and add it
# to _finalized_tasks. If the new task is empty, then we drop it (we already have at least one task for this
# partition, so no use in keeping an additional empty task around).
elif not task.is_empty():
# If we have a task in staging, we know from the first if statement that it's non-empty, so we finalize it
# and add it to _finalized_tasks.
if part_id in self._task_staging:
self._finalized_tasks[part_id].append(
self._task_staging.pop(part_id).finalize_partition_task_single_output(self._stage_id)
)
self._finalized_tasks[part_id].append(task.finalize_partition_task_single_output(self._stage_id))

def finalize(self, part_id: str) -> None:
"""
Indicates to the tracker that we are done adding merge-join tasks for this partition.
"""
# All finalized tasks should have been yielded before the tracker.finalize() call.
finalized_tasks = self._finalized_tasks.pop(part_id, deque())
assert len(finalized_tasks) == 0

self._finalized[part_id] = True

def yield_ready(
self, part_id: str
) -> Iterator[SingleOutputPartitionTask[PartitionT] | PartitionTaskBuilder[PartitionT]]:
"""
Returns an iterator of all tasks for this partition that are ready for execution. Each merge-join task will be
yielded once, even across multiple calls.
"""
assert self._is_contained(part_id)
if part_id in self._finalized_tasks:
# Yield the finalized tasks and add them to the uncoalesced queue.
while self._finalized_tasks[part_id]:
task = self._finalized_tasks[part_id].popleft()
yield task
self._uncoalesced_tasks[part_id].append(task)
elif self._finalized.get(part_id, False) and part_id in self._task_staging:
# If the tracker has been finalized for this partition, we can yield unfinalized tasks directly from
# staging since no future tasks will be added.
yield self._task_staging.pop(part_id)

def pop_uncoalesced(self, part_id: str) -> deque[SingleOutputPartitionTask[PartitionT]] | None:
"""
Returns all tasks for this partition that need to be coalesced. If this partition only involved a single
merge-join task (i.e. we don't need to coalesce), this this function will return None.
NOTE: tracker.finalize(part_id) must be called before this function.
"""
assert self._finalized[part_id]
return self._uncoalesced_tasks.pop(part_id, None)

def all_tasks_done_for_partition(self, part_id: str) -> bool:
"""
Return whether all merge-join tasks for this partition are done.
"""
assert self._is_contained(part_id)
if part_id in self._task_staging:
# Unfinalized tasks are trivially "done".
return True
return all(
task.done()
for task in itertools.chain(
self._finalized_tasks.get(part_id, deque()), self._uncoalesced_tasks.get(part_id, deque())
)
)

def all_tasks_done(self) -> bool:
"""
Return whether all merge-join tasks for all partitions are done.
"""
return all(
self.all_tasks_done_for_partition(part_id)
for part_id in itertools.chain(
self._uncoalesced_tasks.keys(), self._finalized_tasks.keys(), self._task_staging.keys()
)
)

def _is_contained(self, part_id: str) -> bool:
"""
Return whether the provided partition is being tracked by this tracker.
"""
return part_id in self._task_staging or part_id in self._finalized_tasks or part_id in self._uncoalesced_tasks


def _emit_merge_joins_on_window(
next_part: SingleOutputPartitionTask[PartitionT],
other_window: deque[SingleOutputPartitionTask[PartitionT]],
merge_join_partition_tasks: collections.defaultdict[
str, list[PartitionTaskBuilder[PartitionT] | PartitionTask[PartitionT]]
],
stage_id: int,
merge_join_task_tracker: MergeJoinTaskTracker[PartitionT],
flipped: bool,
next_is_larger: bool,
left_on: ExpressionsProjection,
Expand All @@ -431,7 +551,7 @@ def _emit_merge_joins_on_window(
if flipped:
inputs = list(reversed(inputs))
partial_metadatas = list(reversed(partial_metadatas))
join_task: PartitionTaskBuilder[PartitionT] | PartitionTask[PartitionT] = PartitionTaskBuilder[PartitionT](
join_task = PartitionTaskBuilder[PartitionT](
inputs=inputs,
partial_metadatas=partial_metadatas,
resource_request=ResourceRequest(memory_bytes=memory_bytes),
Expand All @@ -446,31 +566,8 @@ def _emit_merge_joins_on_window(
# Add to new merge-join step to tracked steps for this larger-side partition, and possibly start finalizing +
# emitting non-empty join steps if there are now more than one.
part_id = next_part.id() if next_is_larger else other_next_part.id()
tasks = merge_join_partition_tasks[part_id]
if len(tasks) == 1:
# If the only merge-join task we have for this partition is empty, remove it and replace it (below)
# with our new (probably) non-empty merge-join task.
if tasks[0].is_empty():
tasks.pop()
elif not join_task.is_empty():
# There are currently two (probably non-empty) merge-join tasks so we'll need to issue a coalesce
# partition task later, so finalize and yield the first merge-join task now.
# The second will be finalized and yielded below.
unfinalized_task = tasks[0]
assert isinstance(unfinalized_task, PartitionTaskBuilder)
tasks[0] = unfinalized_task.finalize_partition_task_single_output(stage_id)
yield tasks[0]
if not join_task.is_empty() and len(tasks) >= 1:
# There are at least two (probably non-empty) merge-join tasks so we'll need to issue a coalesce partition
# task later, so finalize and yield the new merge-join task now.
assert isinstance(join_task, PartitionTaskBuilder)
join_task = join_task.finalize_partition_task_single_output(stage_id)
yield join_task
# Add merge-join task to appropriate group.
if not join_task.is_empty() or not tasks:
# If merge-join task is empty and there are no merge-join tasks yet for this partition, we still add the
# empty task in case we need to propagate an empty partition later.
tasks.append(join_task)
merge_join_task_tracker.add_task(part_id, join_task)
yield from merge_join_task_tracker.yield_ready(part_id)


def _memory_bytes_for_merge(
Expand Down Expand Up @@ -511,6 +608,8 @@ def merge_join_sorted(
larger_plan = left_plan if left_is_larger else right_plan
smaller_plan = right_plan if left_is_larger else left_plan

stage_id = next(stage_id_counter)

# In-progress tasks for larger side of join.
larger_requests: deque[SingleOutputPartitionTask[PartitionT]] = deque()
# In-progress tasks for smaller side of join.
Expand All @@ -522,17 +621,14 @@ def merge_join_sorted(
# larger-side materialized partition has a higher upper bound, which suggests that this smaller-side partition won't
# be able to intersect with any future larger-side partitions.
smaller_window: deque[SingleOutputPartitionTask[PartitionT]] = deque()
# A map from IDs of input partitions from the larger side of the join to a list of merge-join partition tasks that
# were emitted on said input partition.
# Tracks merge-join partition tasks emitted for each partition on the larger side of the join.
# Once all merge-join tasks are done, the corresponding output partitions will be coalesced together.
# If the merge-join task list only contains a single element, it will be an unfinalized PartitionTaskBuilder,
# the coalescing step will be skipped, and this merge-join task will be yielded without finalizing in order to
# allow fusion with downstream tasks; otherwise, the list will contain finalized PartitionTasks.
merge_join_partition_tasks: collections.defaultdict[
str, list[PartitionTaskBuilder[PartitionT] | PartitionTask[PartitionT]]
] = collections.defaultdict(list)
# If only a single merge-join task is emitted for a larger-side partition, it will be an unfinalized
# PartitionTaskBuilder, the coalescing step will be skipped, and this merge-join task will be yielded without
# finalizing in order to allow fusion with downstream tasks; otherwise, the tracker will contain finalized
# PartitionTasks.
merge_join_task_tracker: MergeJoinTaskTracker[PartitionT] = MergeJoinTaskTracker(stage_id)

stage_id = next(stage_id_counter)
yield_smaller = True
smaller_done = False
larger_done = False
Expand All @@ -547,8 +643,7 @@ def merge_join_sorted(
yield from _emit_merge_joins_on_window(
next_part,
larger_window,
merge_join_partition_tasks,
stage_id,
merge_join_task_tracker,
left_is_larger,
False,
left_on,
Expand All @@ -563,8 +658,7 @@ def merge_join_sorted(
yield from _emit_merge_joins_on_window(
next_part,
smaller_window,
merge_join_partition_tasks,
stage_id,
merge_join_task_tracker,
not left_is_larger,
True,
left_on,
Expand Down Expand Up @@ -602,32 +696,27 @@ def merge_join_sorted(
# Only finalize merge-join tasks for larger-side partition if all outputs are done OR there's only a
# single finalized output (in which case we yield and unfinalized merge-join task to allow downstream
# fusion with it).
all(
isinstance(task, PartitionTaskBuilder) or task.done()
for task in merge_join_partition_tasks[larger_window[0].id()]
)
merge_join_task_tracker.all_tasks_done_for_partition(larger_window[0].id())
)
):
done_larger_part = larger_window.popleft()
tasks = merge_join_partition_tasks.pop(done_larger_part.id())
assert len(tasks) > 0
if len(tasks) == 1:
# Only one output partition, so no coalesce needed; yield the merge-join partition task without
# finalizing to allow for downstream fusing.
# This output partition may be predetermined to be empty.
yield tasks[0]
part_id = done_larger_part.id()
merge_join_task_tracker.finalize(part_id)
yield from merge_join_task_tracker.yield_ready(part_id)
tasks = merge_join_task_tracker.pop_uncoalesced(part_id)
if tasks is None:
# Only one output partition, so no coalesce needed.
continue
tasks_ = cast(List[SingleOutputPartitionTask], tasks)
# At least two (probably non-empty) merge-join tasks for this group, so need to coalesce.
# NOTE: We guarantee in _emit_merge_joins_on_window that any group containing 2 or more partition tasks
# will only contain non-guaranteed-empty partitions; i.e., we'll need to execute a task to determine if
# they actually are empty, so we just issue the coalesce task.
# TODO(Clark): Elide coalescing by emitting a single merge-join task per larger-side partition, including as
# input all intersecting partitions from the smaller side of the join.
size_bytes = _memory_bytes_for_coalesce(tasks_)
size_bytes = _memory_bytes_for_coalesce(tasks)
coalesce_task = PartitionTaskBuilder[PartitionT](
inputs=[task.partition() for task in tasks_],
partial_metadatas=[task.partition_metadata() for task in tasks_],
inputs=[task.partition() for task in tasks],
partial_metadatas=[task.partition_metadata() for task in tasks],
resource_request=ResourceRequest(memory_bytes=size_bytes),
).add_instruction(
instruction=execution_step.ReduceMerge(),
Expand Down Expand Up @@ -680,14 +769,10 @@ def merge_join_sorted(
larger_done = True
# We might still be waiting for some merge-join tasks to complete whose output we still need
# to coalesce.
elif any(
isinstance(task, SingleOutputPartitionTask) and not task.done()
for tasks in merge_join_partition_tasks.values()
for task in tasks
):
elif not merge_join_task_tracker.all_tasks_done():
logger.debug(
"merge join blocked on completion of merge join tasks (pre-coalesce).\nMerge-join tasks: %s",
list(merge_join_partition_tasks.values()),
list(merge_join_task_tracker._finalized_tasks.values()),
)
yield None
# Otherwise, all join inputs are done and all merge-join tasks are done, so we are entirely done emitting
Expand All @@ -709,7 +794,7 @@ def _is_strictly_bounded_above_by(
return lower_boundaries.is_strictly_bounded_above_by(upper_boundaries)


def _memory_bytes_for_coalesce(input_parts: list[SingleOutputPartitionTask[PartitionT]]) -> int | None:
def _memory_bytes_for_coalesce(input_parts: Iterable[SingleOutputPartitionTask[PartitionT]]) -> int | None:
# Calculate memory request for task.
size_bytes_per_task = [task.partition_metadata().size_bytes for task in input_parts]
non_null_size_bytes_per_task = [size for size in size_bytes_per_task if size is not None]
Expand Down

0 comments on commit 716c0e7

Please sign in to comment.