Skip to content

Commit

Permalink
Add broadcast join.
Browse files Browse the repository at this point in the history
  • Loading branch information
clarkzinzow committed Dec 7, 2023
1 parent 29cb3a3 commit 4fb0e12
Show file tree
Hide file tree
Showing 23 changed files with 425 additions and 47 deletions.
8 changes: 6 additions & 2 deletions daft/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,16 +193,19 @@ def set_runner_py(use_thread_pool: bool | None = None) -> DaftContext:
def set_config(
merge_scan_tasks_min_size_bytes: int | None = None,
merge_scan_tasks_max_size_bytes: int | None = None,
broadcast_join_size_bytes_threshold: int | None = None,
) -> DaftContext:
"""Globally sets various configuration parameters which control various aspects of Daft execution
Args:
merge_scan_tasks_min_size_bytes: Minimum size in bytes when merging ScanTasks when reading files from storage.
Increasing this value will make Daft perform more merging of files into a single partition before yielding,
which leads to bigger but fewer partitions. (Defaults to 64MB)
which leads to bigger but fewer partitions. (Defaults to 64 MiB)
merge_scan_tasks_max_size_bytes: Maximum size in bytes when merging ScanTasks when reading files from storage.
Increasing this value will increase the upper bound of the size of merged ScanTasks, which leads to bigger but
fewer partitions. (Defaults to 512MB)
fewer partitions. (Defaults to 512 MiB)
broadcast_join_size_bytes_threshold: If one side of a join is smaller than this threshold, a broadcast join will be used.
Default is 10 MiB.
"""
old_ctx = get_context()

Expand All @@ -211,6 +214,7 @@ def set_config(
new_daft_config = old_daft_config.with_config_values(
merge_scan_tasks_min_size_bytes=merge_scan_tasks_min_size_bytes,
merge_scan_tasks_max_size_bytes=merge_scan_tasks_max_size_bytes,
broadcast_join_size_bytes_threshold=broadcast_join_size_bytes_threshold,
)

new_ctx = dataclasses.replace(
Expand Down
5 changes: 4 additions & 1 deletion daft/daft.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1026,7 +1026,7 @@ class LogicalPlanBuilder:

@staticmethod
def in_memory_scan(
partition_key: str, cache_entry: PartitionCacheEntry, schema: PySchema, num_partitions: int
partition_key: str, cache_entry: PartitionCacheEntry, schema: PySchema, num_partitions: int, size_bytes: int
) -> LogicalPlanBuilder: ...
@staticmethod
def table_scan_with_scan_operator(scan_operator: ScanOperatorHandle) -> LogicalPlanBuilder: ...
Expand Down Expand Up @@ -1070,11 +1070,14 @@ class PyDaftConfig:
self,
merge_scan_tasks_min_size_bytes: int | None = None,
merge_scan_tasks_max_size_bytes: int | None = None,
broadcast_join_size_bytes_threshold: int | None = None,
) -> PyDaftConfig: ...
@property
def merge_scan_tasks_min_size_bytes(self): ...
@property
def merge_scan_tasks_max_size_bytes(self): ...
@property
def broadcast_join_size_bytes_threshold(self): ...

def build_type() -> str: ...
def version() -> str: ...
Expand Down
30 changes: 25 additions & 5 deletions daft/dataframe/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,16 @@ def _builder(self) -> LogicalPlanBuilder:
return self.__builder
else:
num_partitions = self._result_cache.num_partitions()
size_bytes = self._result_cache.size_bytes()
# Partition set should always be set on cache entry.
assert num_partitions is not None, "Partition set should always be set on cache entry"
assert (
num_partitions is not None and size_bytes is not None
), "Partition set should always be set on cache entry"
return self.__builder.from_in_memory_scan(
self._result_cache, self.__builder.schema(), num_partitions=num_partitions
self._result_cache,
self.__builder.schema(),
num_partitions=num_partitions,
size_bytes=size_bytes,
)

def _get_current_builder(self) -> LogicalPlanBuilder:
Expand Down Expand Up @@ -273,7 +279,11 @@ def _from_tables(cls, *parts: MicroPartition) -> "DataFrame":

context = get_context()
cache_entry = context.runner().put_partition_set_into_cache(result_pset)
builder = LogicalPlanBuilder.from_in_memory_scan(cache_entry, parts[0].schema(), result_pset.num_partitions())
size_bytes = result_pset.size_bytes()
assert size_bytes is not None, "In-memory data should always have non-None size in bytes"
builder = LogicalPlanBuilder.from_in_memory_scan(
cache_entry, parts[0].schema(), result_pset.num_partitions(), size_bytes
)
return cls(builder)

###
Expand Down Expand Up @@ -1233,8 +1243,13 @@ def _from_ray_dataset(cls, ds: "RayDataset") -> "DataFrame":

partition_set, schema = ray_runner_io.partition_set_from_ray_dataset(ds)
cache_entry = context.runner().put_partition_set_into_cache(partition_set)
size_bytes = partition_set.size_bytes()
assert size_bytes is not None, "In-memory data should always have non-None size in bytes"
builder = LogicalPlanBuilder.from_in_memory_scan(
cache_entry, schema=schema, num_partitions=partition_set.num_partitions()
cache_entry,
schema=schema,
num_partitions=partition_set.num_partitions(),
size_bytes=size_bytes,
)
return cls(builder)

Expand Down Expand Up @@ -1300,8 +1315,13 @@ def _from_dask_dataframe(cls, ddf: "dask.DataFrame") -> "DataFrame":

partition_set, schema = ray_runner_io.partition_set_from_dask_dataframe(ddf)
cache_entry = context.runner().put_partition_set_into_cache(partition_set)
size_bytes = partition_set.size_bytes()
assert size_bytes is not None, "In-memory data should always have non-None size in bytes"
builder = LogicalPlanBuilder.from_in_memory_scan(
cache_entry, schema=schema, num_partitions=partition_set.num_partitions()
cache_entry,
schema=schema,
num_partitions=partition_set.num_partitions(),
size_bytes=size_bytes,
)
return cls(builder)

Expand Down
11 changes: 10 additions & 1 deletion daft/execution/execution_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,12 +629,21 @@ class Join(SingleOutputInstruction):
left_on: ExpressionsProjection
right_on: ExpressionsProjection
how: JoinType
is_swapped: bool

def run(self, inputs: list[MicroPartition]) -> list[MicroPartition]:
return self._join(inputs)

def _join(self, inputs: list[MicroPartition]) -> list[MicroPartition]:
[left, right] = inputs
*lefts, right = inputs
if len(lefts) > 1:
left = MicroPartition.concat(lefts)
else:
left = lefts[0]
if self.is_swapped:
# Swap left/right back.
# We don't need to swap left_on and right_on since those were never swapped in the first place.
left, right = right, left
result = left.join(
right,
left_on=self.left_on,
Expand Down
128 changes: 127 additions & 1 deletion daft/execution/physical_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def pipeline_instruction(
)


def join(
def hash_join(
left_plan: InProgressPhysicalPlan[PartitionT],
right_plan: InProgressPhysicalPlan[PartitionT],
left_on: ExpressionsProjection,
Expand Down Expand Up @@ -233,6 +233,7 @@ def join(
left_on=left_on,
right_on=right_on,
how=how,
is_swapped=False,
)
)
yield join_step
Expand Down Expand Up @@ -271,6 +272,131 @@ def join(
return


def broadcast_join(
    broadcaster_plan: InProgressPhysicalPlan[PartitionT],
    reciever_plan: InProgressPhysicalPlan[PartitionT],
    left_on: ExpressionsProjection,
    right_on: ExpressionsProjection,
    how: JoinType,
    is_swapped: bool,
) -> InProgressPhysicalPlan[PartitionT]:
    """Broadcast join all partitions from the broadcaster child plan to each partition in the receiver child plan.

    The broadcaster side is fully materialized first; all of its partitions are then shipped
    to every receiver-side partition as the receiver side materializes, emitting one join
    task per receiver partition.

    Args:
        broadcaster_plan: Plan producing the (presumed small) side that is broadcast in full.
        reciever_plan: Plan producing the other side; one join task is emitted per partition.
            NOTE(review): parameter name is misspelled ("reciever") — callers pass it by
            keyword, so renaming it here would be a breaking change.
        left_on: Join key expressions for the logical left side (never swapped).
        right_on: Join key expressions for the logical right side (never swapped).
        how: Join type.
        is_swapped: True when the broadcaster is the logical *right* side of the join; the
            Join instruction then swaps its inputs back before joining.
    """

    def _emit_join_step(
        broadcaster_parts: deque[SingleOutputPartitionTask[PartitionT]],
        receiver_part: SingleOutputPartitionTask[PartitionT],
    ) -> PartitionTaskBuilder[PartitionT]:
        # Calculate memory request for task.
        # Sum of the known broadcaster partition sizes; partitions with unknown size are
        # counted in null_count and imputed below.
        broadcaster_size_bytes_ = 0
        broadcaster_partitions = []
        broadcaster_partition_metadatas = []
        null_count = 0
        for next_broadcaster in broadcaster_parts:
            next_broadcaster_partition_metadata = next_broadcaster.partition_metadata()
            if next_broadcaster_partition_metadata is None or next_broadcaster_partition_metadata.size_bytes is None:
                null_count += 1
            else:
                broadcaster_size_bytes_ += next_broadcaster_partition_metadata.size_bytes
            broadcaster_partitions.append(next_broadcaster.partition())
            broadcaster_partition_metadatas.append(next_broadcaster_partition_metadata)
        if null_count == len(broadcaster_parts):
            # No broadcaster partition has a known size — total is unknown.
            broadcaster_size_bytes = None
        elif null_count > 0:
            # Impute null size estimates with mean of non-null estimates.
            broadcaster_size_bytes = broadcaster_size_bytes_ + math.ceil(
                null_count * broadcaster_size_bytes_ / (len(broadcaster_parts) - null_count)
            )
        else:
            broadcaster_size_bytes = broadcaster_size_bytes_
        receiver_size_bytes = receiver_part.partition_metadata().size_bytes
        # Combine both side estimates into a single memory request; when one side is
        # unknown, extrapolate from the other assuming receiver ≈ 4x broadcaster.
        if broadcaster_size_bytes is None and receiver_size_bytes is None:
            size_bytes = None
        elif broadcaster_size_bytes is None and receiver_size_bytes is not None:
            # Use 1.25x the receiver side as the memory request, assuming that receiver side is ~4x larger than the broadcaster side.
            size_bytes = int(1.25 * receiver_size_bytes)
        elif receiver_size_bytes is None and broadcaster_size_bytes is not None:
            # Use 4x the broadcaster side as the memory request, assuming that receiver side is ~4x larger than the broadcaster side.
            size_bytes = 4 * broadcaster_size_bytes
        elif broadcaster_size_bytes is not None and receiver_size_bytes is not None:
            size_bytes = broadcaster_size_bytes + receiver_size_bytes

        # Inputs are all broadcaster partitions followed by the single receiver partition;
        # the Join instruction concatenates the leading (broadcaster) inputs.
        return PartitionTaskBuilder[PartitionT](
            inputs=broadcaster_partitions + [receiver_part.partition()],
            partial_metadatas=list(broadcaster_partition_metadatas + [receiver_part.partition_metadata()]),
            resource_request=ResourceRequest(memory_bytes=size_bytes),
        ).add_instruction(
            instruction=execution_step.Join(
                left_on=left_on,
                right_on=right_on,
                how=how,
                is_swapped=is_swapped,
            )
        )

    # Materialize the steps from the broadcaster and receiver sources to get partitions.
    # As the receiver-side materializations complete, emit new steps to join each broadcaster and receiver partition.
    stage_id = next(stage_id_counter)
    broadcaster_requests: deque[SingleOutputPartitionTask[PartitionT]] = deque()
    broadcaster_parts: deque[SingleOutputPartitionTask[PartitionT]] = deque()

    # First, fully materialize the broadcasting side (broadcaster side) of the join.
    while True:
        # Move completed partition tasks in the broadcaster side of the join to the materialized partition set.
        while broadcaster_requests and broadcaster_requests[0].done():
            broadcaster_parts.append(broadcaster_requests.popleft())

        # Execute single child step to pull in more broadcaster-side partitions.
        try:
            step = next(broadcaster_plan)
            if isinstance(step, PartitionTaskBuilder):
                step = step.finalize_partition_task_single_output(stage_id=stage_id)
                broadcaster_requests.append(step)
            yield step
        except StopIteration:
            if broadcaster_requests:
                # Tasks are still in flight: yield None to tell the scheduler we are
                # blocked and cannot emit more work yet.
                logger.debug(
                    "broadcast join blocked on completion of broadcasting side of join.\n broadcaster sources: %s",
                    broadcaster_requests,
                )
                yield None
            else:
                break

    # Second, broadcast materialized partitions to receiver side of join, as it materializes.
    receiver_requests: deque[SingleOutputPartitionTask[PartitionT]] = deque()
    receiver_parts: deque[SingleOutputPartitionTask[PartitionT]] = deque()

    while True:
        # Move completed partition tasks in the receiver side of the join to the materialized partition set.
        while receiver_requests and receiver_requests[0].done():
            receiver_parts.append(receiver_requests.popleft())

        # Emit join steps for newly materialized partitions.
        # Broadcast all broadcaster partitions to each new receiver partition that was materialized on this dispatch loop.
        for receiver_part in receiver_parts:
            yield _emit_join_step(broadcaster_parts, receiver_part)

        # We always emit a join step for every newly materialized receiver-side partition, so clear the materialized partition queue.
        receiver_parts.clear()

        # Execute single child step to pull in more input partitions.
        try:
            step = next(reciever_plan)
            if isinstance(step, PartitionTaskBuilder):
                step = step.finalize_partition_task_single_output(stage_id=stage_id)
                receiver_requests.append(step)
            yield step
        except StopIteration:
            if receiver_requests:
                logger.debug(
                    "broadcast join blocked on completion of receiver side of join.\n receiver sources: %s",
                    receiver_requests,
                )
                yield None
            else:
                return


def concat(
top_plan: InProgressPhysicalPlan[PartitionT], bottom_plan: InProgressPhysicalPlan[PartitionT]
) -> InProgressPhysicalPlan[PartitionT]:
Expand Down
24 changes: 22 additions & 2 deletions daft/execution/rust_physical_plan_shim.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def reduce_merge(
return physical_plan.reduce(input, reduce_instruction)


def join(
def hash_join(
input: physical_plan.InProgressPhysicalPlan[PartitionT],
right: physical_plan.InProgressPhysicalPlan[PartitionT],
left_on: list[PyExpr],
Expand All @@ -197,7 +197,7 @@ def join(
) -> physical_plan.InProgressPhysicalPlan[PartitionT]:
left_on_expr_proj = ExpressionsProjection([Expression._from_pyexpr(expr) for expr in left_on])
right_on_expr_proj = ExpressionsProjection([Expression._from_pyexpr(expr) for expr in right_on])
return physical_plan.join(
return physical_plan.hash_join(
left_plan=input,
right_plan=right,
left_on=left_on_expr_proj,
Expand All @@ -206,6 +206,26 @@ def join(
)


def broadcast_join(
    input: physical_plan.InProgressPhysicalPlan[PartitionT],
    right: physical_plan.InProgressPhysicalPlan[PartitionT],
    left_on: list[PyExpr],
    right_on: list[PyExpr],
    join_type: JoinType,
    is_swapped: bool,
) -> physical_plan.InProgressPhysicalPlan[PartitionT]:
    """Shim: wrap Rust-side join key expressions and delegate to the pure-Python broadcast join plan."""

    def _to_projection(exprs: list[PyExpr]) -> ExpressionsProjection:
        # Lift each Rust-side PyExpr into a Python Expression, then bundle into a projection.
        return ExpressionsProjection([Expression._from_pyexpr(e) for e in exprs])

    # NOTE(review): "reciever_plan" [sic] matches the (misspelled) keyword parameter on
    # physical_plan.broadcast_join; it must stay spelled this way to call the function.
    return physical_plan.broadcast_join(
        broadcaster_plan=input,
        reciever_plan=right,
        left_on=_to_projection(left_on),
        right_on=_to_projection(right_on),
        how=join_type,
        is_swapped=is_swapped,
    )


def write_file(
input: physical_plan.InProgressPhysicalPlan[PartitionT],
file_format: FileFormat,
Expand Down
3 changes: 3 additions & 0 deletions daft/io/file_path.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,12 @@ def from_glob_path(path: str, io_config: Optional[IOConfig] = None) -> DataFrame
file_infos_table = MicroPartition._from_pytable(file_infos.to_table())
partition = LocalPartitionSet({0: file_infos_table})
cache_entry = context.runner().put_partition_set_into_cache(partition)
size_bytes = partition.size_bytes()
assert size_bytes is not None, "In-memory data should always have non-None size in bytes"
builder = LogicalPlanBuilder.from_in_memory_scan(
cache_entry,
schema=file_infos_table.schema(),
num_partitions=partition.num_partitions(),
size_bytes=size_bytes,
)
return DataFrame(builder)
6 changes: 4 additions & 2 deletions daft/logical/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,11 @@ def optimize(self) -> LogicalPlanBuilder:

@classmethod
def from_in_memory_scan(
    cls, partition: PartitionCacheEntry, schema: Schema, num_partitions: int, size_bytes: int
) -> LogicalPlanBuilder:
    """Construct a logical plan builder over an in-memory partition set already held in the partition cache.

    Args:
        partition: Cache entry for the materialized partition set; its key identifies the entry.
        schema: Schema of the in-memory data.
        num_partitions: Number of partitions in the cached set.
        size_bytes: Total size of the cached data in bytes (used for e.g. broadcast-join sizing).
    """
    return cls(
        _LogicalPlanBuilder.in_memory_scan(partition.key, partition, schema._schema, num_partitions, size_bytes)
    )

@classmethod
Expand Down
7 changes: 7 additions & 0 deletions daft/runners/partitioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,10 @@ def has_partition(self, idx: PartID) -> bool:
def __len__(self) -> int:
raise NotImplementedError()

@abstractmethod
def size_bytes(self) -> int | None:
    """Return the total size of this partition set in bytes, or None if the size is unknown."""
    raise NotImplementedError()

Check warning on line 180 in daft/runners/partitioning.py

View check run for this annotation

Codecov / codecov/patch

daft/runners/partitioning.py#L180

Added line #L180 was not covered by tests

@abstractmethod
def num_partitions(self) -> int:
raise NotImplementedError()
Expand Down Expand Up @@ -208,6 +212,9 @@ def __setstate__(self, key):
def num_partitions(self) -> int | None:
return self.value.num_partitions() if self.value is not None else None

def size_bytes(self) -> int | None:
return self.value.size_bytes() if self.value is not None else None


class PartitionSetCache:
def __init__(self) -> None:
Expand Down
10 changes: 9 additions & 1 deletion daft/runners/pyrunner.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,15 @@ def has_partition(self, idx: PartID) -> bool:
return idx in self._partitions

def __len__(self) -> int:
    """Total number of rows across all partitions in this set."""
    return sum(map(len, self._partitions.values()))

def size_bytes(self) -> int | None:
size_bytes_ = [partition.size_bytes() for partition in self._partitions.values()]
size_bytes: list[int] = [size for size in size_bytes_ if size is not None]
if len(size_bytes) != len(size_bytes_):
return None

Check warning on line 73 in daft/runners/pyrunner.py

View check run for this annotation

Codecov / codecov/patch

daft/runners/pyrunner.py#L73

Added line #L73 was not covered by tests
else:
return sum(size_bytes)

def num_partitions(self) -> int:
return len(self._partitions)
Expand Down
Loading

0 comments on commit 4fb0e12

Please sign in to comment.