diff --git a/daft/context.py b/daft/context.py index dbec115099..6182c16746 100644 --- a/daft/context.py +++ b/daft/context.py @@ -352,7 +352,7 @@ def set_execution_config( default_morsel_size: int | None = None, shuffle_algorithm: str | None = None, pre_shuffle_merge_threshold: int | None = None, - enable_ray_tracing: bool | None = None, + enable_ray_tracing: int | None = None, ) -> DaftContext: """Globally sets various configuration parameters which control various aspects of Daft execution. @@ -395,7 +395,8 @@ def set_execution_config( default_morsel_size: Default size of morsels used for the new local executor. Defaults to 131072 rows. shuffle_algorithm: The shuffle algorithm to use. Defaults to "map_reduce". Other options are "pre_shuffle_merge". pre_shuffle_merge_threshold: Memory threshold in bytes for pre-shuffle merge. Defaults to 1GB - enable_ray_tracing: Enable tracing for Ray. Accessible in `/tmp/ray/session_latest/logs/daft` after the run completes. Defaults to False. + enable_ray_tracing: Enable tracing for Ray. Accessible in `/tmp/ray/session_latest/logs/daft` after the run completes. Defaults to 0, but + can be set to 1 or 2 depending on the level of tracing desired. Levels 2 and above require `memray` to be installed. """ # Replace values in the DaftExecutionConfig with user-specified overrides ctx = get_context() diff --git a/daft/daft/__init__.pyi b/daft/daft/__init__.pyi index da292d2df1..1d1fbfeb29 100644 --- a/daft/daft/__init__.pyi +++ b/daft/daft/__init__.pyi @@ -1736,7 +1736,7 @@ class PyDaftExecutionConfig: enable_aqe: bool | None = None, enable_native_executor: bool | None = None, default_morsel_size: int | None = None, - enable_ray_tracing: bool | None = None, + enable_ray_tracing: int | None = None, shuffle_algorithm: str | None = None, pre_shuffle_merge_threshold: int | None = None, ) -> PyDaftExecutionConfig: ... @@ -1783,7 +1783,7 @@ class PyDaftExecutionConfig: @property def pre_shuffle_merge_threshold(self) -> int: ... @property - def enable_ray_tracing(self) -> bool: ... + def enable_ray_tracing(self) -> int: ... class PyDaftPlanningConfig: @staticmethod diff --git a/daft/runners/ray_metrics.py b/daft/runners/ray_metrics.py index df542446c6..371a625387 100644 --- a/daft/runners/ray_metrics.py +++ b/daft/runners/ray_metrics.py @@ -51,6 +51,14 @@ class EndTaskEvent(TaskEvent): # End Unix timestamp end: float + memory_stats: TaskMemoryStats | None + + +@dataclasses.dataclass(frozen=True) +class TaskMemoryStats: + peak_memory_allocated: int + total_memory_allocated: int + total_num_allocations: int class _NodeInfo: @@ -123,9 +131,15 @@ def mark_task_start( ) ) - def mark_task_end(self, execution_id: str, task_id: str, end: float): + def mark_task_end( + self, + execution_id: str, + task_id: str, + end: float, + memory_stats: TaskMemoryStats | None, + ): # Add an EndTaskEvent - self._task_events[execution_id].append(EndTaskEvent(task_id=task_id, end=end)) + self._task_events[execution_id].append(EndTaskEvent(task_id=task_id, end=end, memory_stats=memory_stats)) def get_task_events(self, execution_id: str, idx: int) -> tuple[list[TaskEvent], int]: events = self._task_events[execution_id] @@ -177,11 +191,13 @@ def mark_task_end( self, task_id: str, end: float, + memory_stats: TaskMemoryStats | None, ) -> None: self.actor.mark_task_end.remote( self.execution_id, task_id, end, + memory_stats, ) def get_task_events(self, idx: int) -> tuple[list[TaskEvent], int]: diff --git a/daft/runners/ray_tracing.py b/daft/runners/ray_tracing.py index b200651a76..8ed42c3d69 100644 --- a/daft/runners/ray_tracing.py +++ b/daft/runners/ray_tracing.py @@ -10,6 +10,7 @@ import dataclasses import json import logging +import os import pathlib import time from datetime import datetime @@ -50,7 +51,7 @@ def ray_tracer(execution_id: str, daft_execution_config: PyDaftExecutionConfig) # Dump the RayRunner trace if we detect an active Ray session, otherwise we give up and do not write the trace ray_logs_location = get_log_location() filepath: pathlib.Path | None - if ray_logs_location.exists() and daft_execution_config.enable_ray_tracing: + if ray_logs_location.exists() and daft_execution_config.enable_ray_tracing > 0: trace_filename = ( f"trace_RayRunner.{execution_id}.{datetime.replace(datetime.now(), microsecond=0).isoformat()[:-3]}.json" ) @@ -255,6 +256,11 @@ def _flush_task_metrics(self): "ph": RunnerTracer.PHASE_ASYNC_END, "pid": 1, "tid": 2, + "args": { + "memray_peak_memory_allocated": task_event.memory_stats.peak_memory_allocated, + "memray_total_memory_allocated": task_event.memory_stats.total_memory_allocated, + "memray_total_num_allocations": task_event.memory_stats.total_num_allocations, + }, }, ts=end_ts, ) @@ -272,6 +278,11 @@ def _flush_task_metrics(self): "ph": RunnerTracer.PHASE_DURATION_END, "pid": node_idx + RunnerTracer.NODE_PIDS_START, "tid": worker_idx, + "args": { + "memray_peak_memory_allocated": task_event.memory_stats.peak_memory_allocated, + "memray_total_memory_allocated": task_event.memory_stats.total_memory_allocated, + "memray_total_num_allocations": task_event.memory_stats.total_num_allocations, + }, }, ts=end_ts, ) @@ -655,7 +666,9 @@ def __next__(self): @contextlib.contextmanager def collect_ray_task_metrics(execution_id: str, task_id: str, stage_id: int, execution_config: PyDaftExecutionConfig): """Context manager that will ping the metrics actor to record various execution metrics about a given task.""" - if execution_config.enable_ray_tracing: + if execution_config.enable_ray_tracing == 0: + yield + elif execution_config.enable_ray_tracing == 1: import time runtime_context = ray.get_runtime_context() @@ -670,7 +683,46 @@ def collect_ray_task_metrics(execution_id: str, task_id: str, stage_id: int, exe runtime_context.get_assigned_resources(), runtime_context.get_task_id(), ) - yield - metrics_actor.mark_task_end(task_id, time.time()) + try: + yield + finally: + metrics_actor.mark_task_end(task_id, time.time(), memory_stats=None) + elif execution_config.enable_ray_tracing == 2: + import time + + import memray + from memray._memray import compute_statistics + + tmpdir = "/tmp/ray/session_latest/logs/daft/task_memray_dumps" + os.makedirs(tmpdir, exist_ok=True) + memray_tmpfile = os.path.join(tmpdir, f"task-{task_id}.memray.bin") + + runtime_context = ray.get_runtime_context() + metrics_actor = ray_metrics.get_metrics_actor(execution_id) + metrics_actor.mark_task_start( + task_id, + time.time(), + runtime_context.get_node_id(), + runtime_context.get_worker_id(), + stage_id, + runtime_context.get_assigned_resources(), + runtime_context.get_task_id(), + ) + try: + with memray.Tracker(memray_tmpfile, native_traces=True, follow_fork=True): + yield + finally: + stats = compute_statistics(memray_tmpfile) + metrics_actor.mark_task_end( + task_id, + time.time(), + ray_metrics.TaskMemoryStats( + peak_memory_allocated=stats.peak_memory_allocated, + total_memory_allocated=stats.total_memory_allocated, + total_num_allocations=stats.total_num_allocations, + ), + ) else: - yield + raise RuntimeError( + f"Unrecognized value for $DAFT_ENABLE_RAY_TRACING. Expected a number from 0 to 2, but received: {execution_config.enable_ray_tracing}" + ) diff --git a/src/common/daft-config/src/lib.rs b/src/common/daft-config/src/lib.rs index 590fd5cf6c..df682e836f 100644 --- a/src/common/daft-config/src/lib.rs +++ b/src/common/daft-config/src/lib.rs @@ -52,7 +52,7 @@ pub struct DaftExecutionConfig { pub default_morsel_size: usize, pub shuffle_algorithm: String, pub pre_shuffle_merge_threshold: usize, - pub enable_ray_tracing: bool, + pub enable_ray_tracing: u32, } impl Default for DaftExecutionConfig { @@ -80,7 +80,7 @@ impl Default for DaftExecutionConfig { default_morsel_size: 128 * 1024, shuffle_algorithm: "map_reduce".to_string(), pre_shuffle_merge_threshold: 1024 * 1024 * 1024, // 1GB - enable_ray_tracing: false, + enable_ray_tracing: 0, } } } @@ -109,10 +109,12 @@ impl DaftExecutionConfig { cfg.enable_native_executor = true; } let ray_tracing_env_var_name = "DAFT_ENABLE_RAY_TRACING"; - if let Ok(val) = std::env::var(ray_tracing_env_var_name) - && matches!(val.trim().to_lowercase().as_str(), "1" | "true") - { - cfg.enable_ray_tracing = true; + if let Ok(val) = std::env::var(ray_tracing_env_var_name) { + if let Ok(val) = val.trim().parse::() { + cfg.enable_ray_tracing = val; + } else { + log::warn!("Invalid value for DAFT_ENABLE_RAY_TRACING. Expected a number from 0 to 2, but received: {}", val.trim()); + } } let shuffle_algorithm_env_var_name = "DAFT_SHUFFLE_ALGORITHM"; if let Ok(val) = std::env::var(shuffle_algorithm_env_var_name) { diff --git a/src/common/daft-config/src/python.rs b/src/common/daft-config/src/python.rs index 3228263b07..0cb3881a0f 100644 --- a/src/common/daft-config/src/python.rs +++ b/src/common/daft-config/src/python.rs @@ -98,7 +98,7 @@ impl PyDaftExecutionConfig { default_morsel_size: Option, shuffle_algorithm: Option<&str>, pre_shuffle_merge_threshold: Option, - enable_ray_tracing: Option, + enable_ray_tracing: Option, ) -> PyResult { let mut config = self.config.as_ref().clone(); @@ -290,7 +290,7 @@ impl PyDaftExecutionConfig { } #[getter] - fn enable_ray_tracing(&self) -> PyResult { + fn enable_ray_tracing(&self) -> PyResult { Ok(self.config.enable_ray_tracing) } }