Skip to content

Commit

Permalink
[FEAT] [Join Optimizations] Add broadcast join. (#1706)
Browse files Browse the repository at this point in the history
This PR adds a broadcast join implementation as a new join strategy,
where all partitions of a small table are broadcasted to each partition
in the larger table, such that we do a local (hash) join of the entire
small table with each individual partition of the larger table.

## Query Planning

The query planner chooses the broadcast join as its join strategy if one
of the sides of the join is smaller than a preconfigured broadcasting
threshold (set to 10 MiB by default, but is user-configurable).

**Note:** If the smaller side of the join is the right side, we invert
the join for planning and scheduling simplicity so we can always
broadcast the left side; we then swap back to the correct join ordering
when performing the local joins. This means that we always form the
probe table on the left side of the join; a future optimization
(applicable to both the broadcast join and the hash join) would be to
have local joins build the probe table on the smaller side while
preserving the expected column ordering. We would still need to always
build the probe table on the left side of the join if we need to
preserve the row-ordering of the right side of the join, e.g. if the
right side of the join is range-partitioned AND we're doing a broadcast
join.

## Query Scheduling

All partitions for the broadcasting side of the join are first
materialized. Then, as each partition on the receiving side of the join
materialize, we dispatch a hash join task joining all broadcaster
partitions with that single receiving-side partition.

## TODOs

- [x] Test coverage.
- [ ] (Follow-up?) TPC-H benchmarking demonstrating speedup due to use
of broadcast join.
- [ ] (Follow-up) In local joins, build the probe table on the smaller
side of the join.
- [ ] (Follow-up) Add table size approximations for operators that
affect cardinality.
  • Loading branch information
clarkzinzow authored Dec 11, 2023
1 parent 7de0bbe commit 2739cc0
Show file tree
Hide file tree
Showing 23 changed files with 482 additions and 58 deletions.
8 changes: 6 additions & 2 deletions daft/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ def set_config(
config: PyDaftConfig | None = None,
merge_scan_tasks_min_size_bytes: int | None = None,
merge_scan_tasks_max_size_bytes: int | None = None,
broadcast_join_size_bytes_threshold: int | None = None,
) -> DaftContext:
"""Globally sets various configuration parameters which control various aspects of Daft execution
Expand All @@ -203,10 +204,12 @@ def set_config(
that the old (current) config should be used.
merge_scan_tasks_min_size_bytes: Minimum size in bytes when merging ScanTasks when reading files from storage.
Increasing this value will make Daft perform more merging of files into a single partition before yielding,
which leads to bigger but fewer partitions. (Defaults to 64MB)
which leads to bigger but fewer partitions. (Defaults to 64 MiB)
merge_scan_tasks_max_size_bytes: Maximum size in bytes when merging ScanTasks when reading files from storage.
Increasing this value will increase the upper bound of the size of merged ScanTasks, which leads to bigger but
fewer partitions. (Defaults to 512MB)
fewer partitions. (Defaults to 512 MiB)
broadcast_join_size_bytes_threshold: If one side of a join is smaller than this threshold, a broadcast join will be used.
Default is 10 MiB.
"""
ctx = get_context()
if ctx._runner is not None:
Expand All @@ -220,6 +223,7 @@ def set_config(
new_daft_config = old_daft_config.with_config_values(
merge_scan_tasks_min_size_bytes=merge_scan_tasks_min_size_bytes,
merge_scan_tasks_max_size_bytes=merge_scan_tasks_max_size_bytes,
broadcast_join_size_bytes_threshold=broadcast_join_size_bytes_threshold,
)

ctx.daft_config = new_daft_config
Expand Down
5 changes: 4 additions & 1 deletion daft/daft.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1058,7 +1058,7 @@ class LogicalPlanBuilder:

@staticmethod
def in_memory_scan(
partition_key: str, cache_entry: PartitionCacheEntry, schema: PySchema, num_partitions: int
partition_key: str, cache_entry: PartitionCacheEntry, schema: PySchema, num_partitions: int, size_bytes: int
) -> LogicalPlanBuilder: ...
@staticmethod
def table_scan_with_scan_operator(scan_operator: ScanOperatorHandle) -> LogicalPlanBuilder: ...
Expand Down Expand Up @@ -1102,11 +1102,14 @@ class PyDaftConfig:
self,
merge_scan_tasks_min_size_bytes: int | None = None,
merge_scan_tasks_max_size_bytes: int | None = None,
broadcast_join_size_bytes_threshold: int | None = None,
) -> PyDaftConfig: ...
@property
def merge_scan_tasks_min_size_bytes(self): ...
@property
def merge_scan_tasks_max_size_bytes(self): ...
@property
def broadcast_join_size_bytes_threshold(self): ...

def build_type() -> str: ...
def version() -> str: ...
Expand Down
30 changes: 25 additions & 5 deletions daft/dataframe/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,16 @@ def _builder(self) -> LogicalPlanBuilder:
return self.__builder
else:
num_partitions = self._result_cache.num_partitions()
size_bytes = self._result_cache.size_bytes()
# Partition set should always be set on cache entry.
assert num_partitions is not None, "Partition set should always be set on cache entry"
assert (
num_partitions is not None and size_bytes is not None
), "Partition set should always be set on cache entry"
return self.__builder.from_in_memory_scan(
self._result_cache, self.__builder.schema(), num_partitions=num_partitions
self._result_cache,
self.__builder.schema(),
num_partitions=num_partitions,
size_bytes=size_bytes,
)

def _get_current_builder(self) -> LogicalPlanBuilder:
Expand Down Expand Up @@ -273,7 +279,11 @@ def _from_tables(cls, *parts: MicroPartition) -> "DataFrame":

context = get_context()
cache_entry = context.runner().put_partition_set_into_cache(result_pset)
builder = LogicalPlanBuilder.from_in_memory_scan(cache_entry, parts[0].schema(), result_pset.num_partitions())
size_bytes = result_pset.size_bytes()
assert size_bytes is not None, "In-memory data should always have non-None size in bytes"
builder = LogicalPlanBuilder.from_in_memory_scan(
cache_entry, parts[0].schema(), result_pset.num_partitions(), size_bytes
)
return cls(builder)

###
Expand Down Expand Up @@ -1233,8 +1243,13 @@ def _from_ray_dataset(cls, ds: "RayDataset") -> "DataFrame":

partition_set, schema = ray_runner_io.partition_set_from_ray_dataset(ds)
cache_entry = context.runner().put_partition_set_into_cache(partition_set)
size_bytes = partition_set.size_bytes()
assert size_bytes is not None, "In-memory data should always have non-None size in bytes"
builder = LogicalPlanBuilder.from_in_memory_scan(
cache_entry, schema=schema, num_partitions=partition_set.num_partitions()
cache_entry,
schema=schema,
num_partitions=partition_set.num_partitions(),
size_bytes=size_bytes,
)
return cls(builder)

Expand Down Expand Up @@ -1300,8 +1315,13 @@ def _from_dask_dataframe(cls, ddf: "dask.DataFrame") -> "DataFrame":

partition_set, schema = ray_runner_io.partition_set_from_dask_dataframe(ddf)
cache_entry = context.runner().put_partition_set_into_cache(partition_set)
size_bytes = partition_set.size_bytes()
assert size_bytes is not None, "In-memory data should always have non-None size in bytes"
builder = LogicalPlanBuilder.from_in_memory_scan(
cache_entry, schema=schema, num_partitions=partition_set.num_partitions()
cache_entry,
schema=schema,
num_partitions=partition_set.num_partitions(),
size_bytes=size_bytes,
)
return cls(builder)

Expand Down
13 changes: 12 additions & 1 deletion daft/execution/execution_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,12 +629,23 @@ class Join(SingleOutputInstruction):
left_on: ExpressionsProjection
right_on: ExpressionsProjection
how: JoinType
is_swapped: bool

def run(self, inputs: list[MicroPartition]) -> list[MicroPartition]:
return self._join(inputs)

def _join(self, inputs: list[MicroPartition]) -> list[MicroPartition]:
[left, right] = inputs
# All inputs except for the last are the left side of the join, in order to support left-broadcasted joins.
*lefts, right = inputs
if len(lefts) > 1:
# NOTE: MicroPartition concats don't concatenate the underlying column arrays, since MicroPartitions are chunked.
left = MicroPartition.concat(lefts)
else:
left = lefts[0]
if self.is_swapped:
# Swap left/right back.
# We don't need to swap left_on and right_on since those were never swapped in the first place.
left, right = right, left
result = left.join(
right,
left_on=self.left_on,
Expand Down
130 changes: 129 additions & 1 deletion daft/execution/physical_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def pipeline_instruction(
)


def join(
def hash_join(
left_plan: InProgressPhysicalPlan[PartitionT],
right_plan: InProgressPhysicalPlan[PartitionT],
left_on: ExpressionsProjection,
Expand Down Expand Up @@ -233,6 +233,7 @@ def join(
left_on=left_on,
right_on=right_on,
how=how,
is_swapped=False,
)
)
yield join_step
Expand Down Expand Up @@ -271,6 +272,133 @@ def join(
return


def _create_join_step(
broadcaster_parts: deque[SingleOutputPartitionTask[PartitionT]],
receiver_part: SingleOutputPartitionTask[PartitionT],
left_on: ExpressionsProjection,
right_on: ExpressionsProjection,
how: JoinType,
is_swapped: bool,
) -> PartitionTaskBuilder[PartitionT]:
# Calculate memory request for task.
broadcaster_size_bytes_ = 0
broadcaster_partitions = []
broadcaster_partition_metadatas = []
null_count = 0
for next_broadcaster in broadcaster_parts:
next_broadcaster_partition_metadata = next_broadcaster.partition_metadata()
if next_broadcaster_partition_metadata is None or next_broadcaster_partition_metadata.size_bytes is None:
null_count += 1
else:
broadcaster_size_bytes_ += next_broadcaster_partition_metadata.size_bytes
broadcaster_partitions.append(next_broadcaster.partition())
broadcaster_partition_metadatas.append(next_broadcaster_partition_metadata)
if null_count == len(broadcaster_parts):
broadcaster_size_bytes = None
elif null_count > 0:
# Impute null size estimates with mean of non-null estimates.
broadcaster_size_bytes = broadcaster_size_bytes_ + math.ceil(
null_count * broadcaster_size_bytes_ / (len(broadcaster_parts) - null_count)
)
else:
broadcaster_size_bytes = broadcaster_size_bytes_
receiver_size_bytes = receiver_part.partition_metadata().size_bytes
if broadcaster_size_bytes is None and receiver_size_bytes is None:
size_bytes = None
elif broadcaster_size_bytes is None and receiver_size_bytes is not None:
# Use 1.25x the receiver side as the memory request, assuming that receiver side is ~4x larger than the broadcaster side.
size_bytes = int(1.25 * receiver_size_bytes)
elif receiver_size_bytes is None and broadcaster_size_bytes is not None:
# Use 4x the broadcaster side as the memory request, assuming that receiver side is ~4x larger than the broadcaster side.
size_bytes = 4 * broadcaster_size_bytes
elif broadcaster_size_bytes is not None and receiver_size_bytes is not None:
size_bytes = broadcaster_size_bytes + receiver_size_bytes

return PartitionTaskBuilder[PartitionT](
inputs=broadcaster_partitions + [receiver_part.partition()],
partial_metadatas=list(broadcaster_partition_metadatas + [receiver_part.partition_metadata()]),
resource_request=ResourceRequest(memory_bytes=size_bytes),
).add_instruction(
instruction=execution_step.Join(
left_on=left_on,
right_on=right_on,
how=how,
is_swapped=is_swapped,
)
)


def broadcast_join(
broadcaster_plan: InProgressPhysicalPlan[PartitionT],
receiver_plan: InProgressPhysicalPlan[PartitionT],
left_on: ExpressionsProjection,
right_on: ExpressionsProjection,
how: JoinType,
is_swapped: bool,
) -> InProgressPhysicalPlan[PartitionT]:
"""Broadcast join all partitions from the broadcaster child plan to each partition in the receiver child plan."""

# Materialize the steps from the broadcaster and receiver sources to get partitions.
# As the receiver-side materializations complete, emit new steps to join each broadcaster and receiver partition.
stage_id = next(stage_id_counter)
broadcaster_requests: deque[SingleOutputPartitionTask[PartitionT]] = deque()
broadcaster_parts: deque[SingleOutputPartitionTask[PartitionT]] = deque()

# First, fully materialize the broadcasting side (broadcaster side) of the join.
while True:
# Moved completed partition tasks in the broadcaster side of the join to the materialized partition set.
while broadcaster_requests and broadcaster_requests[0].done():
broadcaster_parts.append(broadcaster_requests.popleft())

# Execute single child step to pull in more broadcaster-side partitions.
try:
step = next(broadcaster_plan)
if isinstance(step, PartitionTaskBuilder):
step = step.finalize_partition_task_single_output(stage_id=stage_id)
broadcaster_requests.append(step)
yield step
except StopIteration:
if broadcaster_requests:
logger.debug(
"broadcast join blocked on completion of broadcasting side of join.\n broadcaster sources: %s",
broadcaster_requests,
)
yield None
else:
break

# Second, broadcast materialized partitions to receiver side of join, as it materializes.
receiver_requests: deque[SingleOutputPartitionTask[PartitionT]] = deque()

while True:
receiver_parts: deque[SingleOutputPartitionTask[PartitionT]] = deque()
# Moved completed partition tasks in the receiver side of the join to the materialized partition set.
while receiver_requests and receiver_requests[0].done():
receiver_parts.append(receiver_requests.popleft())

# Emit join steps for newly materialized partitions.
# Broadcast all broadcaster partitions to each new receiver partition that was materialized on this dispatch loop.
for receiver_part in receiver_parts:
yield _create_join_step(broadcaster_parts, receiver_part, left_on, right_on, how, is_swapped)

# Execute single child step to pull in more input partitions.
try:
step = next(receiver_plan)
if isinstance(step, PartitionTaskBuilder):
step = step.finalize_partition_task_single_output(stage_id=stage_id)
receiver_requests.append(step)
yield step
except StopIteration:
if receiver_requests:
logger.debug(
"broadcast join blocked on completion of receiver side of join.\n receiver sources: %s",
receiver_requests,
)
yield None
else:
return


def concat(
top_plan: InProgressPhysicalPlan[PartitionT], bottom_plan: InProgressPhysicalPlan[PartitionT]
) -> InProgressPhysicalPlan[PartitionT]:
Expand Down
24 changes: 22 additions & 2 deletions daft/execution/rust_physical_plan_shim.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def reduce_merge(
return physical_plan.reduce(input, reduce_instruction)


def join(
def hash_join(
input: physical_plan.InProgressPhysicalPlan[PartitionT],
right: physical_plan.InProgressPhysicalPlan[PartitionT],
left_on: list[PyExpr],
Expand All @@ -197,7 +197,7 @@ def join(
) -> physical_plan.InProgressPhysicalPlan[PartitionT]:
left_on_expr_proj = ExpressionsProjection([Expression._from_pyexpr(expr) for expr in left_on])
right_on_expr_proj = ExpressionsProjection([Expression._from_pyexpr(expr) for expr in right_on])
return physical_plan.join(
return physical_plan.hash_join(
left_plan=input,
right_plan=right,
left_on=left_on_expr_proj,
Expand All @@ -206,6 +206,26 @@ def join(
)


def broadcast_join(
broadcaster: physical_plan.InProgressPhysicalPlan[PartitionT],
receiver: physical_plan.InProgressPhysicalPlan[PartitionT],
left_on: list[PyExpr],
right_on: list[PyExpr],
join_type: JoinType,
is_swapped: bool,
) -> physical_plan.InProgressPhysicalPlan[PartitionT]:
left_on_expr_proj = ExpressionsProjection([Expression._from_pyexpr(expr) for expr in left_on])
right_on_expr_proj = ExpressionsProjection([Expression._from_pyexpr(expr) for expr in right_on])
return physical_plan.broadcast_join(
broadcaster_plan=broadcaster,
receiver_plan=receiver,
left_on=left_on_expr_proj,
right_on=right_on_expr_proj,
how=join_type,
is_swapped=is_swapped,
)


def write_file(
input: physical_plan.InProgressPhysicalPlan[PartitionT],
file_format: FileFormat,
Expand Down
3 changes: 3 additions & 0 deletions daft/io/file_path.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,12 @@ def from_glob_path(path: str, io_config: Optional[IOConfig] = None) -> DataFrame
file_infos_table = MicroPartition._from_pytable(file_infos.to_table())
partition = LocalPartitionSet({0: file_infos_table})
cache_entry = context.runner().put_partition_set_into_cache(partition)
size_bytes = partition.size_bytes()
assert size_bytes is not None, "In-memory data should always have non-None size in bytes"
builder = LogicalPlanBuilder.from_in_memory_scan(
cache_entry,
schema=file_infos_table.schema(),
num_partitions=partition.num_partitions(),
size_bytes=size_bytes,
)
return DataFrame(builder)
6 changes: 4 additions & 2 deletions daft/logical/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,11 @@ def optimize(self) -> LogicalPlanBuilder:

@classmethod
def from_in_memory_scan(
cls, partition: PartitionCacheEntry, schema: Schema, num_partitions: int
cls, partition: PartitionCacheEntry, schema: Schema, num_partitions: int, size_bytes: int
) -> LogicalPlanBuilder:
builder = _LogicalPlanBuilder.in_memory_scan(partition.key, partition, schema._schema, num_partitions)
builder = _LogicalPlanBuilder.in_memory_scan(
partition.key, partition, schema._schema, num_partitions, size_bytes
)
return cls(builder)

@classmethod
Expand Down
7 changes: 7 additions & 0 deletions daft/runners/partitioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,10 @@ def has_partition(self, idx: PartID) -> bool:
def __len__(self) -> int:
raise NotImplementedError()

@abstractmethod
def size_bytes(self) -> int | None:
raise NotImplementedError()

@abstractmethod
def num_partitions(self) -> int:
raise NotImplementedError()
Expand Down Expand Up @@ -208,6 +212,9 @@ def __setstate__(self, key):
def num_partitions(self) -> int | None:
return self.value.num_partitions() if self.value is not None else None

def size_bytes(self) -> int | None:
return self.value.size_bytes() if self.value is not None else None


class PartitionSetCache:
def __init__(self) -> None:
Expand Down
Loading

0 comments on commit 2739cc0

Please sign in to comment.