[FEAT] [Join Optimizations] Add broadcast join. #1706

Merged · 4 commits · Dec 11, 2023
8 changes: 6 additions & 2 deletions daft/context.py
@@ -195,6 +195,7 @@ def set_config(
config: PyDaftConfig | None = None,
merge_scan_tasks_min_size_bytes: int | None = None,
merge_scan_tasks_max_size_bytes: int | None = None,
broadcast_join_size_bytes_threshold: int | None = None,
) -> DaftContext:
"""Globally sets various configuration parameters which control various aspects of Daft execution

@@ -203,10 +204,12 @@
that the old (current) config should be used.
merge_scan_tasks_min_size_bytes: Minimum size in bytes when merging ScanTasks when reading files from storage.
Increasing this value will make Daft perform more merging of files into a single partition before yielding,
- which leads to bigger but fewer partitions. (Defaults to 64MB)
which leads to bigger but fewer partitions. (Defaults to 64 MiB)
merge_scan_tasks_max_size_bytes: Maximum size in bytes when merging ScanTasks when reading files from storage.
Increasing this value will increase the upper bound of the size of merged ScanTasks, which leads to bigger but
- fewer partitions. (Defaults to 512MB)
fewer partitions. (Defaults to 512 MiB)
broadcast_join_size_bytes_threshold: If one side of a join is smaller than this threshold, a broadcast join will be used.
Default is 10 MiB.
"""
ctx = get_context()
if ctx.disallow_set_runner:
@@ -220,6 +223,7 @@
new_daft_config = old_daft_config.with_config_values(
merge_scan_tasks_min_size_bytes=merge_scan_tasks_min_size_bytes,
merge_scan_tasks_max_size_bytes=merge_scan_tasks_max_size_bytes,
broadcast_join_size_bytes_threshold=broadcast_join_size_bytes_threshold,
)

ctx.daft_config = new_daft_config
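For reference, a minimal usage sketch of the new knob (assuming set_config is importable as daft.context.set_config, as this file suggests; the 100 MiB value is arbitrary):

import daft

# Raise the broadcast-join threshold from the documented 10 MiB default to
# 100 MiB, so that larger build sides still qualify for a broadcast join.
daft.context.set_config(broadcast_join_size_bytes_threshold=100 * 1024 * 1024)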
5 changes: 4 additions & 1 deletion daft/daft.pyi
@@ -1058,7 +1058,7 @@ class LogicalPlanBuilder:

@staticmethod
def in_memory_scan(
- partition_key: str, cache_entry: PartitionCacheEntry, schema: PySchema, num_partitions: int
partition_key: str, cache_entry: PartitionCacheEntry, schema: PySchema, num_partitions: int, size_bytes: int
) -> LogicalPlanBuilder: ...
@staticmethod
def table_scan_with_scan_operator(scan_operator: ScanOperatorHandle) -> LogicalPlanBuilder: ...
@@ -1102,11 +1102,14 @@ class PyDaftConfig:
self,
merge_scan_tasks_min_size_bytes: int | None = None,
merge_scan_tasks_max_size_bytes: int | None = None,
broadcast_join_size_bytes_threshold: int | None = None,
) -> PyDaftConfig: ...
@property
def merge_scan_tasks_min_size_bytes(self): ...
@property
def merge_scan_tasks_max_size_bytes(self): ...
@property
def broadcast_join_size_bytes_threshold(self): ...

def build_type() -> str: ...
def version() -> str: ...
30 changes: 25 additions & 5 deletions daft/dataframe/dataframe.py
@@ -85,10 +85,16 @@ def _builder(self) -> LogicalPlanBuilder:
return self.__builder
else:
num_partitions = self._result_cache.num_partitions()
size_bytes = self._result_cache.size_bytes()
# Partition set should always be set on cache entry.
- assert num_partitions is not None, "Partition set should always be set on cache entry"
assert (
num_partitions is not None and size_bytes is not None
), "Partition set should always be set on cache entry"
return self.__builder.from_in_memory_scan(
- self._result_cache, self.__builder.schema(), num_partitions=num_partitions
self._result_cache,
self.__builder.schema(),
num_partitions=num_partitions,
size_bytes=size_bytes,
)

def _get_current_builder(self) -> LogicalPlanBuilder:
@@ -273,7 +279,11 @@ def _from_tables(cls, *parts: MicroPartition) -> "DataFrame":

context = get_context()
cache_entry = context.runner().put_partition_set_into_cache(result_pset)
- builder = LogicalPlanBuilder.from_in_memory_scan(cache_entry, parts[0].schema(), result_pset.num_partitions())
size_bytes = result_pset.size_bytes()
assert size_bytes is not None, "In-memory data should always have non-None size in bytes"
builder = LogicalPlanBuilder.from_in_memory_scan(
cache_entry, parts[0].schema(), result_pset.num_partitions(), size_bytes
)
return cls(builder)

###
@@ -1233,8 +1243,13 @@ def _from_ray_dataset(cls, ds: "RayDataset") -> "DataFrame":

partition_set, schema = ray_runner_io.partition_set_from_ray_dataset(ds)
cache_entry = context.runner().put_partition_set_into_cache(partition_set)
size_bytes = partition_set.size_bytes()
assert size_bytes is not None, "In-memory data should always have non-None size in bytes"
builder = LogicalPlanBuilder.from_in_memory_scan(
- cache_entry, schema=schema, num_partitions=partition_set.num_partitions()
cache_entry,
schema=schema,
num_partitions=partition_set.num_partitions(),
size_bytes=size_bytes,
)
return cls(builder)

@@ -1300,8 +1315,13 @@ def _from_dask_dataframe(cls, ddf: "dask.DataFrame") -> "DataFrame":

partition_set, schema = ray_runner_io.partition_set_from_dask_dataframe(ddf)
cache_entry = context.runner().put_partition_set_into_cache(partition_set)
size_bytes = partition_set.size_bytes()
assert size_bytes is not None, "In-memory data should always have non-None size in bytes"
builder = LogicalPlanBuilder.from_in_memory_scan(
- cache_entry, schema=schema, num_partitions=partition_set.num_partitions()
cache_entry,
schema=schema,
num_partitions=partition_set.num_partitions(),
size_bytes=size_bytes,
)
return cls(builder)

13 changes: 12 additions & 1 deletion daft/execution/execution_step.py
@@ -629,12 +629,23 @@ class Join(SingleOutputInstruction):
left_on: ExpressionsProjection
right_on: ExpressionsProjection
how: JoinType
is_swapped: bool

def run(self, inputs: list[MicroPartition]) -> list[MicroPartition]:
return self._join(inputs)

def _join(self, inputs: list[MicroPartition]) -> list[MicroPartition]:
- [left, right] = inputs
# All inputs except for the last are the left side of the join, in order to support left-broadcasted joins.
*lefts, right = inputs
if len(lefts) > 1:
# NOTE: MicroPartition concats don't concatenate the underlying column arrays, since MicroPartitions are chunked.
left = MicroPartition.concat(lefts)
else:
left = lefts[0]
if self.is_swapped:
# Swap left/right back.
# We don't need to swap left_on and right_on since those were never swapped in the first place.
left, right = right, left
result = left.join(
right,
left_on=self.left_on,
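To make the new input convention concrete, here is a minimal sketch in plain Python (lists stand in for MicroPartitions, and concat is a stand-in for MicroPartition.concat; names are illustrative, not Daft's actual API):

def concat(parts):
    # Stand-in for MicroPartition.concat: flatten a list of row-lists.
    return [row for part in parts for row in part]

def split_join_inputs(inputs, is_swapped):
    # All inputs except the last are the broadcast side; the receiver comes last.
    *lefts, right = inputs
    left = lefts[0] if len(lefts) == 1 else concat(lefts)
    if is_swapped:
        # The broadcast side was logically the right side of the join, so swap back.
        # left_on/right_on stay as they are: they were never swapped.
        left, right = right, left
    return left, right

With is_swapped=True, the planner can broadcast the logical right side of the join while leaving the join keys untouched.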
130 changes: 129 additions & 1 deletion daft/execution/physical_plan.py
@@ -186,7 +186,7 @@
)


- def join(
def hash_join(
left_plan: InProgressPhysicalPlan[PartitionT],
right_plan: InProgressPhysicalPlan[PartitionT],
left_on: ExpressionsProjection,
@@ -233,6 +233,7 @@
left_on=left_on,
right_on=right_on,
how=how,
is_swapped=False,
)
)
yield join_step
@@ -271,6 +272,133 @@
return


def _create_join_step(
broadcaster_parts: deque[SingleOutputPartitionTask[PartitionT]],
receiver_part: SingleOutputPartitionTask[PartitionT],
left_on: ExpressionsProjection,
right_on: ExpressionsProjection,
how: JoinType,
is_swapped: bool,
) -> PartitionTaskBuilder[PartitionT]:
# Calculate memory request for task.
broadcaster_size_bytes_ = 0
broadcaster_partitions = []
broadcaster_partition_metadatas = []
null_count = 0
for next_broadcaster in broadcaster_parts:
next_broadcaster_partition_metadata = next_broadcaster.partition_metadata()
if next_broadcaster_partition_metadata is None or next_broadcaster_partition_metadata.size_bytes is None:
null_count += 1

else:
broadcaster_size_bytes_ += next_broadcaster_partition_metadata.size_bytes
broadcaster_partitions.append(next_broadcaster.partition())
broadcaster_partition_metadatas.append(next_broadcaster_partition_metadata)
if null_count == len(broadcaster_parts):
broadcaster_size_bytes = None

elif null_count > 0:
# Impute null size estimates with mean of non-null estimates.
broadcaster_size_bytes = broadcaster_size_bytes_ + math.ceil(

null_count * broadcaster_size_bytes_ / (len(broadcaster_parts) - null_count)
)
else:
broadcaster_size_bytes = broadcaster_size_bytes_
receiver_size_bytes = receiver_part.partition_metadata().size_bytes
if broadcaster_size_bytes is None and receiver_size_bytes is None:
size_bytes = None

elif broadcaster_size_bytes is None and receiver_size_bytes is not None:
# Use 1.25x the receiver side as the memory request, assuming that receiver side is ~4x larger than the broadcaster side.
size_bytes = int(1.25 * receiver_size_bytes)

elif receiver_size_bytes is None and broadcaster_size_bytes is not None:
# Use 4x the broadcaster side as the memory request, assuming that receiver side is ~4x larger than the broadcaster side.
size_bytes = 4 * broadcaster_size_bytes

elif broadcaster_size_bytes is not None and receiver_size_bytes is not None:
size_bytes = broadcaster_size_bytes + receiver_size_bytes

return PartitionTaskBuilder[PartitionT](
inputs=broadcaster_partitions + [receiver_part.partition()],
partial_metadatas=list(broadcaster_partition_metadatas + [receiver_part.partition_metadata()]),
resource_request=ResourceRequest(memory_bytes=size_bytes),
).add_instruction(
instruction=execution_step.Join(
left_on=left_on,
right_on=right_on,
how=how,
is_swapped=is_swapped,
)
)


def broadcast_join(
broadcaster_plan: InProgressPhysicalPlan[PartitionT],
receiver_plan: InProgressPhysicalPlan[PartitionT],
left_on: ExpressionsProjection,
right_on: ExpressionsProjection,
how: JoinType,
is_swapped: bool,
) -> InProgressPhysicalPlan[PartitionT]:
"""Broadcast join all partitions from the broadcaster child plan to each partition in the receiver child plan."""

# Materialize the steps from the broadcaster and receiver sources to get partitions.
# As the receiver-side materializations complete, emit new steps to join each broadcaster and receiver partition.
stage_id = next(stage_id_counter)
broadcaster_requests: deque[SingleOutputPartitionTask[PartitionT]] = deque()
broadcaster_parts: deque[SingleOutputPartitionTask[PartitionT]] = deque()

# First, fully materialize the broadcaster side of the join.
while True:
# Move completed partition tasks from the broadcaster side of the join to the materialized partition set.
while broadcaster_requests and broadcaster_requests[0].done():
broadcaster_parts.append(broadcaster_requests.popleft())

# Execute single child step to pull in more broadcaster-side partitions.
try:
step = next(broadcaster_plan)
if isinstance(step, PartitionTaskBuilder):
step = step.finalize_partition_task_single_output(stage_id=stage_id)
broadcaster_requests.append(step)
yield step
except StopIteration:
if broadcaster_requests:
logger.debug(
"broadcast join blocked on completion of broadcasting side of join.\n broadcaster sources: %s",
broadcaster_requests,
)
yield None
else:
break

# Second, broadcast the materialized partitions to the receiver side of the join as it materializes.
receiver_requests: deque[SingleOutputPartitionTask[PartitionT]] = deque()

while True:
receiver_parts: deque[SingleOutputPartitionTask[PartitionT]] = deque()
# Move completed partition tasks from the receiver side of the join to the materialized partition set.
while receiver_requests and receiver_requests[0].done():
receiver_parts.append(receiver_requests.popleft())

# Emit join steps for newly materialized partitions.
# Broadcast all broadcaster partitions to each new receiver partition that was materialized on this dispatch loop.
for receiver_part in receiver_parts:
yield _create_join_step(broadcaster_parts, receiver_part, left_on, right_on, how, is_swapped)

# Execute single child step to pull in more input partitions.
try:
step = next(receiver_plan)
if isinstance(step, PartitionTaskBuilder):
step = step.finalize_partition_task_single_output(stage_id=stage_id)
receiver_requests.append(step)
yield step
except StopIteration:
if receiver_requests:
logger.debug(
"broadcast join blocked on completion of receiver side of join.\n receiver sources: %s",
receiver_requests,
)
yield None
else:
return


def concat(
top_plan: InProgressPhysicalPlan[PartitionT], bottom_plan: InProgressPhysicalPlan[PartitionT]
) -> InProgressPhysicalPlan[PartitionT]:
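The size-estimation logic in _create_join_step packs several heuristics into one if/elif chain, so a worked sketch may help. Below is a hypothetical standalone helper mirroring the arithmetic above; it is not code from this PR:

import math

def estimate_join_memory_bytes(broadcaster_sizes, receiver_size):
    # broadcaster_sizes: per-partition sizes in bytes, None where unknown.
    # receiver_size: size of the receiver partition in bytes, or None.
    known = [s for s in broadcaster_sizes if s is not None]
    null_count = len(broadcaster_sizes) - len(known)
    if not known:
        broadcaster_size = None
    elif null_count > 0:
        # Impute unknown sizes with the mean of the known estimates.
        broadcaster_size = sum(known) + math.ceil(null_count * sum(known) / len(known))
    else:
        broadcaster_size = sum(known)
    if broadcaster_size is None and receiver_size is None:
        return None
    if broadcaster_size is None:
        # Assume the receiver is ~4x the broadcaster: request 1.25x the receiver.
        return int(1.25 * receiver_size)
    if receiver_size is None:
        # Same 4x assumption, applied from the other direction.
        return 4 * broadcaster_size
    return broadcaster_size + receiver_size

For example, with two known 8 MiB broadcaster partitions, one unknown partition, and a 64 MiB receiver, the unknown is imputed as 8 MiB and the request is 24 MiB + 64 MiB = 88 MiB.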
24 changes: 22 additions & 2 deletions daft/execution/rust_physical_plan_shim.py
@@ -188,7 +188,7 @@ def reduce_merge(
return physical_plan.reduce(input, reduce_instruction)


- def join(
def hash_join(
input: physical_plan.InProgressPhysicalPlan[PartitionT],
right: physical_plan.InProgressPhysicalPlan[PartitionT],
left_on: list[PyExpr],
@@ -197,7 +197,7 @@ def join(
) -> physical_plan.InProgressPhysicalPlan[PartitionT]:
left_on_expr_proj = ExpressionsProjection([Expression._from_pyexpr(expr) for expr in left_on])
right_on_expr_proj = ExpressionsProjection([Expression._from_pyexpr(expr) for expr in right_on])
- return physical_plan.join(
return physical_plan.hash_join(
left_plan=input,
right_plan=right,
left_on=left_on_expr_proj,
@@ -206,6 +206,26 @@ def join(
)


def broadcast_join(
broadcaster: physical_plan.InProgressPhysicalPlan[PartitionT],
receiver: physical_plan.InProgressPhysicalPlan[PartitionT],
left_on: list[PyExpr],
right_on: list[PyExpr],
join_type: JoinType,
is_swapped: bool,
) -> physical_plan.InProgressPhysicalPlan[PartitionT]:
left_on_expr_proj = ExpressionsProjection([Expression._from_pyexpr(expr) for expr in left_on])
right_on_expr_proj = ExpressionsProjection([Expression._from_pyexpr(expr) for expr in right_on])
return physical_plan.broadcast_join(
broadcaster_plan=broadcaster,
receiver_plan=receiver,
left_on=left_on_expr_proj,
right_on=right_on_expr_proj,
how=join_type,
is_swapped=is_swapped,
)


def write_file(
input: physical_plan.InProgressPhysicalPlan[PartitionT],
file_format: FileFormat,
3 changes: 3 additions & 0 deletions daft/io/file_path.py
@@ -48,9 +48,12 @@ def from_glob_path(path: str, io_config: Optional[IOConfig] = None) -> DataFrame
file_infos_table = MicroPartition._from_pytable(file_infos.to_table())
partition = LocalPartitionSet({0: file_infos_table})
cache_entry = context.runner().put_partition_set_into_cache(partition)
size_bytes = partition.size_bytes()
assert size_bytes is not None, "In-memory data should always have non-None size in bytes"
builder = LogicalPlanBuilder.from_in_memory_scan(
cache_entry,
schema=file_infos_table.schema(),
num_partitions=partition.num_partitions(),
size_bytes=size_bytes,
)
return DataFrame(builder)
6 changes: 4 additions & 2 deletions daft/logical/builder.py
@@ -74,9 +74,11 @@ def optimize(self) -> LogicalPlanBuilder:

@classmethod
def from_in_memory_scan(
- cls, partition: PartitionCacheEntry, schema: Schema, num_partitions: int
cls, partition: PartitionCacheEntry, schema: Schema, num_partitions: int, size_bytes: int
) -> LogicalPlanBuilder:
- builder = _LogicalPlanBuilder.in_memory_scan(partition.key, partition, schema._schema, num_partitions)
builder = _LogicalPlanBuilder.in_memory_scan(
partition.key, partition, schema._schema, num_partitions, size_bytes
)
return cls(builder)

@classmethod
7 changes: 7 additions & 0 deletions daft/runners/partitioning.py
@@ -175,6 +175,10 @@
def __len__(self) -> int:
raise NotImplementedError()

@abstractmethod
def size_bytes(self) -> int | None:
raise NotImplementedError()


@abstractmethod
def num_partitions(self) -> int:
raise NotImplementedError()
@@ -208,6 +212,9 @@
def num_partitions(self) -> int | None:
return self.value.num_partitions() if self.value is not None else None

def size_bytes(self) -> int | None:
return self.value.size_bytes() if self.value is not None else None


class PartitionSetCache:
def __init__(self) -> None:
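A hypothetical concrete implementation of the new abstract size_bytes method, for a partition set backed by in-memory tables (illustrative only; not the actual LocalPartitionSet internals):

class InMemoryPartitionSet:
    def __init__(self, partitions):
        # partitions: dict mapping partition id -> object exposing size_bytes().
        self._partitions = partitions

    def size_bytes(self):
        # Report the total size, or None if any partition cannot report one.
        sizes = [p.size_bytes() for p in self._partitions.values()]
        if any(s is None for s in sizes):
            return None
        return sum(sizes)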