[Data] Emit warning if local shuffle buffer would cause spilling #48925

Open · wants to merge 3 commits into master
63 changes: 61 additions & 2 deletions python/ray/data/_internal/batcher.py
@@ -1,9 +1,13 @@
import warnings
from typing import Optional

import ray
from ray.data._internal.arrow_block import ArrowBlockAccessor
from ray.data._internal.arrow_ops import transform_pyarrow
from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder
from ray.data._internal.execution.util import memory_string
from ray.data.block import Block, BlockAccessor
from ray.util import log_once

# pyarrow.Table.slice is slow when the table has many chunks
# so we combine chunks into a single one to make slice faster
@@ -214,6 +218,10 @@ def __init__(
        self._batch_head = 0
        self._done_adding = False

        self._total_object_store_nbytes = _get_total_obj_store_mem_on_node()
        self._total_rows_added = 0
        self._total_nbytes_added = 0

    def add(self, block: Block):
        """Add a block to the shuffle buffer.

@@ -222,9 +230,51 @@ def add(self, block: Block):
        Args:
            block: Block to add to the shuffle buffer.
        """
        if BlockAccessor.for_block(block).num_rows() > 0:
        # Because Arrow tables are memory mapped, blocks in the local shuffle buffer
        # reside in object store memory and not local heap memory. So, if you specify
        # a large buffer size and there isn't enough object store memory on the node,
        # you might encounter spilling.

Comment on lines +233 to +236 (Contributor):

This comment is confusing, actually:

  • There's no relation between the shuffle buffer and the object store.
  • The produced block will likely get into the object store only once it's yielded from the operator (that's using the batcher).

        if (
            self._estimated_max_buffer_nbytes is not None
            and self._estimated_max_buffer_nbytes > self._total_object_store_nbytes

Contributor:

Nit: 1) This if statement can be put under `if BlockAccessor.for_block(block).num_rows() > 0:` and after `add_block`, because that guarantees `_estimated_max_buffer_nbytes` is not None. 2) We can skip recalculating `_estimated_max_buffer_nbytes` after it's been calculated once.
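
A rough sketch of that restructuring (hypothetical; assumes the counters are updated before the estimate is checked, so the estimate is always available):

    def add(self, block: Block):
        block_accessor = BlockAccessor.for_block(block)
        if block_accessor.num_rows() > 0:
            self._builder.add_block(block)
            self._total_rows_added += block_accessor.num_rows()
            self._total_nbytes_added += block_accessor.size_bytes()
            # At least one non-empty block has been added at this point, so
            # `_estimated_max_buffer_nbytes` is guaranteed to be non-None.
            if (
                self._estimated_max_buffer_nbytes > self._total_object_store_nbytes
                and log_once("shuffle_buffer_mem_warning")
            ):
                warnings.warn("...")  # Same message as in the diff.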

Contributor:

@bveeramani this value should be doubled:

  • One half is blocks before batching
  • Other half is new block produced
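
A sketch of the doubled estimate (hypothetical; the factor of 2 covers the two halves described above):

    @property
    def _estimated_max_buffer_nbytes(self) -> Optional[int]:
        if self._average_row_nbytes is None:
            return None
        # One buffer's worth of blocks awaiting batching, plus roughly the
        # same amount again for the newly produced block.
        return int(
            2
            * self._average_row_nbytes
            * self._buffer_min_size
            * SHUFFLE_BUFFER_COMPACTION_RATIO
        )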

            and log_once("shuffle_buffer_mem_warning")
        ):
            warnings.warn(
                "The node you're iterating on has "
                f"{memory_string(self._total_object_store_nbytes)} object "
                "store memory, but the shuffle buffer is estimated to use "
                f"{memory_string(self._estimated_max_buffer_nbytes)}. If you don't "
                f"decrease the shuffle buffer size from {self._buffer_min_size} rows, "
                "you might encounter spilling."

Comment on lines +242 to +248 (Member Author):

e.g.,

UserWarning: The node you're iterating on has 128.0MB object store memory, but the shuffle buffer is estimated to use 384.0MB. If you don't decrease the shuffle buffer size from 2 rows, you might encounter spilling.

            )

        block_accessor = BlockAccessor.for_block(block)
        self._total_rows_added += block_accessor.num_rows()
        self._total_nbytes_added += block_accessor.size_bytes()
        if block_accessor.num_rows() > 0:
            self._builder.add_block(block)

Comment on lines +251 to +255 (Contributor):

If we're not adding the block, we'd still be changing the counters.
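
A minimal sketch of the fix, moving the counter updates under the same non-empty check:

        block_accessor = BlockAccessor.for_block(block)
        if block_accessor.num_rows() > 0:
            self._total_rows_added += block_accessor.num_rows()
            self._total_nbytes_added += block_accessor.size_bytes()
            self._builder.add_block(block)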


    @property
    def _average_row_nbytes(self) -> Optional[int]:
        """Return the average number of bytes per row added to the shuffle buffer."""
        return (
            self._total_nbytes_added // self._total_rows_added
            if self._total_rows_added > 0
            else None
        )

    @property
    def _estimated_max_buffer_nbytes(self) -> Optional[int]:
        """Return the estimated maximum number of bytes in the shuffle buffer."""
        if self._average_row_nbytes is None:
            return None

        return (
            self._average_row_nbytes
            * self._buffer_min_size
            * SHUFFLE_BUFFER_COMPACTION_RATIO

Comment on lines +274 to +275 (Contributor):

Let's extract this to a common util
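
One possible shape for that util (hypothetical name and location; assumes SHUFFLE_BUFFER_COMPACTION_RATIO moves along with it):

    # e.g., in python/ray/data/_internal/util.py
    def estimate_shuffle_buffer_nbytes(avg_row_nbytes: int, buffer_min_size: int) -> int:
        """Estimate the peak number of bytes held by a local shuffle buffer."""
        return int(avg_row_nbytes * buffer_min_size * SHUFFLE_BUFFER_COMPACTION_RATIO)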

        )

    def done_adding(self) -> bool:
        """Indicate to the batcher that no more blocks will be added to the batcher.

@@ -251,7 +301,7 @@ def has_batch(self) -> bool:
        return buffer_size >= self._batch_size

    def _buffer_size(self) -> int:
        """Return shuffle buffer size."""
        """Return number of rows in shuffle buffer."""
        buffer_size = self._builder.num_rows()
        buffer_size += self._materialized_buffer_size()
        return buffer_size
@@ -323,3 +373,12 @@ def next_batch(self) -> Block:
        return BlockAccessor.for_block(self._shuffle_buffer).slice(
            slice_start, self._batch_head
        )


def _get_total_obj_store_mem_on_node() -> int:
    node_id = ray.get_runtime_context().get_node_id()
    total_resources_per_node = ray._private.state.total_resources_per_node()

Member Author:

IIRC this API requires an RPC. Since this is only called once per iteration, I think the performance should be good enough.

Contributor:

> IIRC this API requires an RPC

Why don't we move this into DataContext and cache it there?

Contributor:

maybe put it in data/_internal/util.py, as it is a utility function
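
A sketch combining both suggestions (hypothetical; uses functools.cache instead of caching on DataContext, so the underlying RPC is issued at most once per process):

    # e.g., in python/ray/data/_internal/util.py
    import functools

    import ray


    @functools.cache
    def get_total_obj_store_mem_on_node() -> int:
        node_id = ray.get_runtime_context().get_node_id()
        total_resources_per_node = ray._private.state.total_resources_per_node()
        assert (
            node_id in total_resources_per_node
        ), f"Expected node '{node_id}' to be in resources: {total_resources_per_node}"
        return total_resources_per_node[node_id]["object_store_memory"]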

    assert (
        node_id in total_resources_per_node
    ), f"Expected node '{node_id}' to be in resources: {total_resources_per_node}"
    return total_resources_per_node[node_id]["object_store_memory"]
@@ -81,8 +81,9 @@ def validate_schema(self, schema: Optional[Union[type, "pyarrow.lib.Schema"]]):
        for column in self._columns:
            if column not in schema_names_set:
                raise ValueError(
                    "The column '{}' does not exist in the "
                    "schema '{}'.".format(column, schema)
                    f"You specified the column '{column}', but there's no such "
                    "column in the dataset. The dataset has columns: "
                    f"{schema_names_set}"
                )

    @property
16 changes: 16 additions & 0 deletions python/ray/data/tests/test_batcher.py
@@ -247,6 +247,22 @@ def test_local_shuffle_determinism(batch_size, local_shuffle_buffer_size):
    assert all(batch_map[batch["id"][0]]["id"] == batch["id"])


def test_local_shuffle_buffer_warns_if_too_large(shutdown_only):
    ray.shutdown()
    ray.init(object_store_memory=128 * 1024 * 1024)

    # Each row is 16 Mi int64 elements * 8 bytes = 128 MiB.
    ds = ray.data.range_tensor(2, shape=(16, 1024, 1024))

    # Test that Ray Data emits a warning if the local shuffle buffer size would
    # cause spilling.
    with pytest.warns(UserWarning, match="shuffle buffer"):
        # Each row is 128 MiB and the shuffle buffer size is 2 rows, so expect at
        # least 256 MiB of memory usage > 128 MiB total on node.
        batches = ds.iter_batches(batch_size=1, local_shuffle_buffer_size=2)
        next(iter(batches))


if __name__ == "__main__":
    import sys
