perf: OOM investigations of map-only pipelines #3558

Draft · wants to merge 13 commits into main
2 changes: 2 additions & 0 deletions .github/workflows/run-cluster.yaml
@@ -123,6 +123,7 @@ jobs:
            --runtime-env-json "$ray_env_var" \
            -- python ${{ inputs.entrypoint_script }} ${{ inputs.entrypoint_args }}
      - name: Download log files from ray cluster
        if: always()
        run: |
          source .venv/bin/activate
          ray rsync-down .github/assets/ray.yaml /tmp/ray/session_*/logs ray-daft-logs
@@ -152,6 +153,7 @@ jobs:
          source .venv/bin/activate
          ray down .github/assets/ray.yaml -y
      - name: Upload log files
        if: always()
        uses: actions/upload-artifact@v4
        with:
          name: ray-daft-logs
71 changes: 71 additions & 0 deletions benchmarking/ooms/big_task_heap_usage.py
@@ -0,0 +1,71 @@
# /// script
# dependencies = ['numpy', 'memray']
# ///

import argparse
import functools

import pyarrow as pa

import daft
from daft.io._generator import read_generator
from daft.table.table import Table

NUM_PARTITIONS = 8


@daft.udf(return_dtype=daft.DataType.binary())
def mock_inflate_data(data, inflation_factor):
    return pa.array([x * inflation_factor for x in data.to_pylist()], type=pa.large_binary())


@daft.udf(return_dtype=daft.DataType.binary())
def mock_deflate_data(data, deflation_factor):
    return [x[: int(len(x) / deflation_factor)] for x in data.to_pylist()]


def generate(num_rows_per_partition):
    yield Table.from_pydict({"foo": [b"x" for _ in range(num_rows_per_partition)]})


def generator(
    num_partitions: int,
    num_rows_per_partition: int,
):
    """Generate data for all partitions."""
    for i in range(num_partitions):
        yield functools.partial(generate, num_rows_per_partition)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Runs a simple map workload with two custom UDFs which first inflate the data and then deflate it. "
        "It starts with 1KB partitions, runs inflation, then deflation. We expect this to OOM if the heap memory usage exceeds "
        "`MEM / N_CPUS` on a given worker node."
    )
    parser.add_argument("--num-partitions", type=int, default=8)
    parser.add_argument("--num-rows-per-partition", type=int, default=1000)
    parser.add_argument("--inflation-factor", type=int, default=100)
    parser.add_argument("--deflation-factor", type=int, default=100)
    args = parser.parse_args()

    daft.context.set_runner_ray()

    df = read_generator(
        generator(args.num_partitions, args.num_rows_per_partition),
        schema=daft.Schema._from_field_name_and_types([("foo", daft.DataType.binary())]),
    )

    df.collect()
    print(df)

    # Big memory explosion
    df = df.with_column("foo", mock_inflate_data(df["foo"], args.inflation_factor))

    # Big memory reduction
    df = df.with_column("foo", mock_deflate_data(df["foo"], args.deflation_factor))

    df.explain(True)

    df.collect()
    print(df)
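
As a rough sanity check on the sizing claim in the parser description, here is a back-of-the-envelope estimate of per-partition heap usage during the inflation step, under the assumption that no extra copies are made by Arrow or Daft internals:

# Back-of-the-envelope per-partition heap estimate for the inflation step.
# Assumes no extra copies beyond the inflated buffers themselves.
num_rows_per_partition = 1000
row_bytes = 1  # each row starts as the single byte b"x", so a partition is ~1KB
inflation_factor = 100

inflated_partition_bytes = num_rows_per_partition * row_bytes * inflation_factor
print(inflated_partition_bytes)  # ~100KB with defaults; larger factors push this toward MEM / N_CPUS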
43 changes: 24 additions & 19 deletions daft/datatype.py
@@ -576,38 +576,43 @@ def __hash__(self) -> int:
DataTypeLike = Union[DataType, type]


import threading

_EXT_TYPE_REGISTERED = False
_STATIC_DAFT_EXTENSION = None
_ext_type_lock = threading.Lock()


def _ensure_registered_super_ext_type():
    global _EXT_TYPE_REGISTERED
    global _STATIC_DAFT_EXTENSION
    if not _EXT_TYPE_REGISTERED:

        class DaftExtension(pa.ExtensionType):
            def __init__(self, dtype, metadata=b""):
                # attributes need to be set first before calling
                # super init (as that calls serialize)
                self._metadata = metadata
                super().__init__(dtype, "daft.super_extension")
    with _ext_type_lock:
        if not _EXT_TYPE_REGISTERED:

            class DaftExtension(pa.ExtensionType):
                def __init__(self, dtype, metadata=b""):
                    # attributes need to be set first before calling
                    # super init (as that calls serialize)
                    self._metadata = metadata
                    super().__init__(dtype, "daft.super_extension")

            def __reduce__(self):
                return type(self).__arrow_ext_deserialize__, (self.storage_type, self.__arrow_ext_serialize__())
                def __reduce__(self):
                    return type(self).__arrow_ext_deserialize__, (self.storage_type, self.__arrow_ext_serialize__())

            def __arrow_ext_serialize__(self):
                return self._metadata
                def __arrow_ext_serialize__(self):
                    return self._metadata

            @classmethod
            def __arrow_ext_deserialize__(cls, storage_type, serialized):
                return cls(storage_type, serialized)
                @classmethod
                def __arrow_ext_deserialize__(cls, storage_type, serialized):
                    return cls(storage_type, serialized)

        _STATIC_DAFT_EXTENSION = DaftExtension
        pa.register_extension_type(DaftExtension(pa.null()))
        import atexit
            _STATIC_DAFT_EXTENSION = DaftExtension
            pa.register_extension_type(DaftExtension(pa.null()))
            import atexit

        atexit.register(lambda: pa.unregister_extension_type("daft.super_extension"))
        _EXT_TYPE_REGISTERED = True
            atexit.register(lambda: pa.unregister_extension_type("daft.super_extension"))
            _EXT_TYPE_REGISTERED = True


def get_super_ext_type():
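
Because the removed and added lines interleave above, here is the post-change registration function reassembled from just the new lines, purely for readability (a reconstruction, not the file verbatim; `pa` is the module's existing `pyarrow` import and the module-level context is abbreviated):

import threading

import pyarrow as pa

_EXT_TYPE_REGISTERED = False
_STATIC_DAFT_EXTENSION = None
_ext_type_lock = threading.Lock()


def _ensure_registered_super_ext_type():
    global _EXT_TYPE_REGISTERED
    global _STATIC_DAFT_EXTENSION
    # The lock makes the one-time registration safe under concurrent callers.
    with _ext_type_lock:
        if not _EXT_TYPE_REGISTERED:

            class DaftExtension(pa.ExtensionType):
                def __init__(self, dtype, metadata=b""):
                    # attributes need to be set first before calling
                    # super init (as that calls serialize)
                    self._metadata = metadata
                    super().__init__(dtype, "daft.super_extension")

                def __reduce__(self):
                    return type(self).__arrow_ext_deserialize__, (self.storage_type, self.__arrow_ext_serialize__())

                def __arrow_ext_serialize__(self):
                    return self._metadata

                @classmethod
                def __arrow_ext_deserialize__(cls, storage_type, serialized):
                    return cls(storage_type, serialized)

            _STATIC_DAFT_EXTENSION = DaftExtension
            pa.register_extension_type(DaftExtension(pa.null()))
            import atexit

            atexit.register(lambda: pa.unregister_extension_type("daft.super_extension"))
            _EXT_TYPE_REGISTERED = True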
20 changes: 18 additions & 2 deletions daft/runners/ray_metrics.py
@@ -51,6 +51,14 @@ class EndTaskEvent(TaskEvent):

    # End Unix timestamp
    end: float
    memory_stats: TaskMemoryStats


@dataclasses.dataclass(frozen=True)
class TaskMemoryStats:
    peak_memory_allocated: int
    total_memory_allocated: int
    total_num_allocations: int


class _NodeInfo:
@@ -123,9 +131,15 @@ def mark_task_start(
            )
        )

    def mark_task_end(self, execution_id: str, task_id: str, end: float):
    def mark_task_end(
        self,
        execution_id: str,
        task_id: str,
        end: float,
        memory_stats: TaskMemoryStats,
    ):
        # Add an EndTaskEvent
        self._task_events[execution_id].append(EndTaskEvent(task_id=task_id, end=end))
        self._task_events[execution_id].append(EndTaskEvent(task_id=task_id, end=end, memory_stats=memory_stats))

    def get_task_events(self, execution_id: str, idx: int) -> tuple[list[TaskEvent], int]:
        events = self._task_events[execution_id]
@@ -177,11 +191,13 @@ def mark_task_end(
        self,
        task_id: str,
        end: float,
        memory_stats: TaskMemoryStats,
    ) -> None:
        self.actor.mark_task_end.remote(
            self.execution_id,
            task_id,
            end,
            memory_stats,
        )

    def get_task_events(self, idx: int) -> tuple[list[TaskEvent], int]:
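
For illustration, the end-of-task report now carries the memray statistics alongside the timestamp; a minimal sketch of the new call shape (the numeric values are made up, and the client object is the per-execution wrapper shown above):

import time

from daft.runners.ray_metrics import TaskMemoryStats

# Hypothetical numbers, purely for illustration.
stats = TaskMemoryStats(
    peak_memory_allocated=64 * 1024 * 1024,
    total_memory_allocated=256 * 1024 * 1024,
    total_num_allocations=12_345,
)

# The client-side call shape then becomes: mark_task_end(task_id, time.time(), stats).
print(stats, time.time())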
33 changes: 31 additions & 2 deletions daft/runners/ray_tracing.py
@@ -10,6 +10,7 @@
import dataclasses
import json
import logging
import os
import pathlib
import time
from datetime import datetime
@@ -255,6 +256,11 @@ def _flush_task_metrics(self):
                    "ph": RunnerTracer.PHASE_ASYNC_END,
                    "pid": 1,
                    "tid": 2,
                    "args": {
                        "memray_peak_memory_allocated": task_event.memory_stats.peak_memory_allocated,
                        "memray_total_memory_allocated": task_event.memory_stats.total_memory_allocated,
                        "memray_total_num_allocations": task_event.memory_stats.total_num_allocations,
                    },
                },
                ts=end_ts,
            )
@@ -272,6 +278,11 @@ def _flush_task_metrics(self):
                    "ph": RunnerTracer.PHASE_DURATION_END,
                    "pid": node_idx + RunnerTracer.NODE_PIDS_START,
                    "tid": worker_idx,
                    "args": {
                        "memray_peak_memory_allocated": task_event.memory_stats.peak_memory_allocated,
                        "memray_total_memory_allocated": task_event.memory_stats.total_memory_allocated,
                        "memray_total_num_allocations": task_event.memory_stats.total_num_allocations,
                    },
                },
                ts=end_ts,
            )
@@ -658,6 +669,9 @@ def collect_ray_task_metrics(execution_id: str, task_id: str, stage_id: int, exe
    if execution_config.enable_ray_tracing:
        import time

        import memray
        from memray._memray import compute_statistics

        runtime_context = ray.get_runtime_context()

        metrics_actor = ray_metrics.get_metrics_actor(execution_id)
@@ -670,7 +684,22 @@ def collect_ray_task_metrics(execution_id: str, task_id: str, stage_id: int, exe
            runtime_context.get_assigned_resources(),
            runtime_context.get_task_id(),
        )
        yield
        metrics_actor.mark_task_end(task_id, time.time())
        tmpdir = "/tmp/ray/session_latest/logs/daft/task_memray_dumps"
        os.makedirs(tmpdir, exist_ok=True)
        memray_tmpfile = os.path.join(tmpdir, f"task-{task_id}.memray.bin")
        try:
            with memray.Tracker(memray_tmpfile, native_traces=True, follow_fork=True):
                yield
        finally:
            stats = compute_statistics(memray_tmpfile)
            metrics_actor.mark_task_end(
                task_id,
                time.time(),
                ray_metrics.TaskMemoryStats(
                    peak_memory_allocated=stats.peak_memory_allocated,
                    total_memory_allocated=stats.total_memory_allocated,
                    total_num_allocations=stats.total_num_allocations,
                ),
            )
    else:
        yield
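
The instrumentation above boils down to running the task body under a memray.Tracker and reading the dump back with compute_statistics; a minimal standalone sketch of that pattern (the path and workload are illustrative, and compute_statistics is memray's private API, used the same way as in this diff):

import os
import tempfile

import memray
from memray._memray import compute_statistics  # private memray API, as used above


def workload():
    # Stand-in for whatever runs under the tracer's `yield`.
    return [b"x" * 1024 for _ in range(10_000)]


dump_path = os.path.join(tempfile.gettempdir(), "example.memray.bin")
try:
    with memray.Tracker(dump_path, native_traces=True, follow_fork=True):
        workload()
finally:
    stats = compute_statistics(dump_path)
    print(
        f"peak={stats.peak_memory_allocated} "
        f"total={stats.total_memory_allocated} "
        f"allocations={stats.total_num_allocations}"
    )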
Empty file added tests/memory/__init__.py
94 changes: 94 additions & 0 deletions tests/memory/test_udf_project.py
@@ -0,0 +1,94 @@
import uuid

import pyarrow as pa
import pytest
from memray._memray import compute_statistics

import daft
from daft.execution.execution_step import ExpressionsProjection, Project
from tests.memory.utils import run_wrapper_build_partitions


def format_bytes(bytes_value):
    """Format bytes into a human-readable string with an appropriate unit."""
    for unit in ["B", "KB", "MB", "GB"]:
        if bytes_value < 1024:
            return f"{bytes_value:.2f} {unit}"
        bytes_value /= 1024
    return f"{bytes_value:.2f} TB"


@daft.udf(return_dtype=str)
def to_arrow_identity(s):
    data = s.to_arrow()
    return data


@daft.udf(return_dtype=str)
def to_pylist_identity(s):
    data = s.to_pylist()
    return data


@daft.udf(return_dtype=str, batch_size=128)
def to_arrow_identity_batched(s):
    data = s.to_arrow()
    return data


@daft.udf(return_dtype=str, batch_size=128)
def to_pylist_identity_batched(s):
    data = s.to_pylist()
    return data


@daft.udf(return_dtype=str, batch_size=128)
def to_pylist_identity_batched_arrow_return(s):
    data = s.to_pylist()
    return pa.array(data)


@pytest.mark.parametrize(
    "udf",
    [
        to_arrow_identity,
        to_pylist_identity,
        to_arrow_identity_batched,
        to_pylist_identity_batched,
        to_pylist_identity_batched_arrow_return,
    ],
)
def test_short_string_identity_projection(udf):
    instructions = [Project(ExpressionsProjection([udf(daft.col("a"))]))]
    inputs = [{"a": [str(uuid.uuid4()) for _ in range(62500)]}]
    _, memray_file = run_wrapper_build_partitions(inputs, instructions)
    stats = compute_statistics(memray_file)

    expected_peak_bytes = 100
    assert stats.peak_memory_allocated < expected_peak_bytes, (
        f"Peak memory ({format_bytes(stats.peak_memory_allocated)}) "
        f"exceeded threshold ({format_bytes(expected_peak_bytes)})"
    )


@pytest.mark.parametrize(
    "udf",
    [
        to_arrow_identity,
        to_pylist_identity,
        to_arrow_identity_batched,
        to_pylist_identity_batched,
        to_pylist_identity_batched_arrow_return,
    ],
)
def test_long_string_identity_projection(udf):
    instructions = [Project(ExpressionsProjection([udf(daft.col("a"))]))]
    inputs = [{"a": [str(uuid.uuid4()) for _ in range(625000)]}]
    _, memray_file = run_wrapper_build_partitions(inputs, instructions)
    stats = compute_statistics(memray_file)

    expected_peak_bytes = 100
    assert stats.peak_memory_allocated < expected_peak_bytes, (
        f"Peak memory ({format_bytes(stats.peak_memory_allocated)}) "
        f"exceeded threshold ({format_bytes(expected_peak_bytes)})"
    )
31 changes: 31 additions & 0 deletions tests/memory/utils.py
@@ -0,0 +1,31 @@
import logging
import os
import tempfile
import uuid
from unittest import mock

import memray

from daft.execution.execution_step import Instruction
from daft.runners.ray_runner import build_partitions
from daft.table import MicroPartition

logger = logging.getLogger(__name__)


def run_wrapper_build_partitions(
    input_partitions: list[dict], instructions: list[Instruction]
) -> tuple[list[MicroPartition], str]:
    inputs = [MicroPartition.from_pydict(p) for p in input_partitions]

    logger.info("Input total size: %s", sum(i.size_bytes() for i in inputs))

    tmpdir = tempfile.gettempdir()
    memray_path = os.path.join(tmpdir, f"memray-{uuid.uuid4()}.bin")
    with memray.Tracker(memray_path, native_traces=True, follow_fork=True):
        results = build_partitions(
            instructions,
            [mock.Mock() for _ in range(len(input_partitions))],
            *inputs,
        )
    return results[1:], memray_path