Cleanup using defaultdict
Jay Chia committed Oct 27, 2024
1 parent 4b2a426 commit 92074e1
Showing 2 changed files with 56 additions and 73 deletions.
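
For context, the pattern being cleaned up: hand-rolled get-or-create helpers are replaced with collections.defaultdict, whose factory runs automatically on first access of a missing key. A standalone sketch of the before/after (illustrative names, not the repo's code):

    # Standalone sketch of the cleanup; names here are illustrative.
    from collections import defaultdict

    # Before: every access goes through an explicit membership check.
    metrics: dict[str, dict[str, float]] = {}

    def get_or_create(execution_id: str) -> dict[str, float]:
        if execution_id not in metrics:
            metrics[execution_id] = {}
        return metrics[execution_id]

    # After: a missing key is populated by the factory on first lookup.
    metrics_dd: dict[str, dict[str, float]] = defaultdict(dict)
    metrics_dd["exec-1"]["task-a"] = 1.0  # "exec-1" entry created on demand

    # Factories nest, as in the commit's defaultdict(lambda: defaultdict(lambda: set())).
    # (defaultdict(set) is an equivalent, terser spelling of the inner factory.)
    node_workers: dict[str, dict[str, set[str]]] = defaultdict(lambda: defaultdict(set))
    node_workers["exec-1"]["node-1"].add("worker-1")

Note that defaultdict lookups insert: reading self.execution_metrics[execution_id] for an unknown ID creates an empty entry, which matches what the removed _get_or_create_* helpers did.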
57 changes: 29 additions & 28 deletions daft/runners/ray_metrics.py
@@ -2,6 +2,8 @@
 
 import dataclasses
 import logging
+import threading
+from collections import defaultdict
 
 logger = logging.getLogger(__name__)

@@ -31,39 +33,27 @@ class TaskMetric:
 class ExecutionMetrics:
     """Holds the metrics for a given execution ID"""
 
-    daft_execution_id: str
     task_start_info: dict[str, TaskMetric] = dataclasses.field(default_factory=lambda: {})
     task_ends: dict[str, float] = dataclasses.field(default_factory=lambda: {})
 
 
 @ray.remote(num_cpus=0)
 class _MetricsActor:
     def __init__(self):
-        self.execution_metrics: dict[str, ExecutionMetrics] = {}
-        self.execution_node_and_worker_ids: dict[str, dict[str, set[str]]] = {}
-
-    def _get_or_create_execution_metrics(self, execution_id: str) -> ExecutionMetrics:
-        if execution_id not in self.execution_metrics:
-            self.execution_metrics[execution_id] = ExecutionMetrics(daft_execution_id=execution_id)
-        return self.execution_metrics[execution_id]
-
-    def _get_or_create_execution_node_and_worker_ids(self, execution_id: str) -> dict[str, set[str]]:
-        if execution_id not in self.execution_node_and_worker_ids:
-            self.execution_node_and_worker_ids[execution_id] = {}
-        return self.execution_node_and_worker_ids[execution_id]
+        self.execution_metrics: dict[str, ExecutionMetrics] = defaultdict(lambda: ExecutionMetrics())
+        self.execution_node_and_worker_ids: dict[str, dict[str, set[str]]] = defaultdict(
+            lambda: defaultdict(lambda: set())
+        )
 
     def mark_task_start(
         self, execution_id: str, task_id: str, start: float, node_id: str, worker_id: str, stage_id: int
     ):
         # Update node info
         node_id_trunc, worker_id_trunc = node_id[:8], worker_id[:8]
-        node_info = self._get_or_create_execution_node_and_worker_ids(execution_id)
-        if node_id_trunc not in node_info:
-            node_info[node_id_trunc] = set()
-        node_info[node_id_trunc].add(worker_id_trunc)
+        self.execution_node_and_worker_ids[execution_id][node_id_trunc].add(worker_id_trunc)
 
         # Update task info
-        self._get_or_create_execution_metrics(execution_id).task_start_info[task_id] = TaskMetric(
+        self.execution_metrics[execution_id].task_start_info[task_id] = TaskMetric(
             task_id=task_id,
             stage_id=stage_id,
             start=start,
@@ -73,18 +63,18 @@ def mark_task_start(
         )
 
     def mark_task_end(self, execution_id: str, task_id: str, end: float):
-        self._get_or_create_execution_metrics(execution_id).task_ends[task_id] = end
+        self.execution_metrics[execution_id].task_ends[task_id] = end
 
     def collect_metrics(self, execution_id: str) -> tuple[list[TaskMetric], dict[str, set[str]]]:
         """Collect the metrics associated with this execution, cleaning up the memory used for this execution ID"""
-        execution_metrics = self._get_or_create_execution_metrics(execution_id)
+        execution_metrics = self.execution_metrics[execution_id]
         data = [
             dataclasses.replace(
                 execution_metrics.task_start_info[task_id], end=execution_metrics.task_ends.get(task_id)
             )
             for task_id in execution_metrics.task_start_info
         ]
-        node_data = self._get_or_create_execution_node_and_worker_ids(execution_id)
+        node_data = self.execution_node_and_worker_ids[execution_id]
 
         # Clean up the stats for this execution
         del self.execution_metrics[execution_id]
@@ -130,12 +120,23 @@ def collect_metrics(self) -> tuple[list[TaskMetric], dict[str, set[str]]]:
         return ray.get(self.actor.collect_metrics.remote(self.execution_id))
 
 
+# Creating/getting an actor from multiple threads is not safe.
+#
+# This could be a problem because our Scheduler does multithreaded executions of plans if multiple
+# plans are submitted at once.
+#
+# Pattern from Ray Data's _StatsActor:
+# https://github.com/ray-project/ray/blob/0b1d0d8f01599796e1109060821583e270048b6e/python/ray/data/_internal/stats.py#L447-L449
+_metrics_actor_lock: threading.RLock = threading.RLock()
+
+
 def get_metrics_actor(execution_id: str) -> MetricsActorHandle:
     """Retrieves a handle to the Actor for a given job_id"""
-    actor = _MetricsActor.options(  # type: ignore[attr-defined]
-        name="METRICS_ACTOR_NAME",
-        namespace=METRICS_ACTOR_NAMESPACE,
-        get_if_exists=True,
-        lifetime="detached",
-    ).remote()
-    return MetricsActorHandle(execution_id, actor)
+    with _metrics_actor_lock:
+        actor = _MetricsActor.options(  # type: ignore[attr-defined]
+            name="METRICS_ACTOR_NAME",
+            namespace=METRICS_ACTOR_NAMESPACE,
+            get_if_exists=True,
+            lifetime="detached",
+        ).remote()
+        return MetricsActorHandle(execution_id, actor)
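
The new module-level lock serializes Ray's named-actor get-or-create, which is not thread-safe when invoked concurrently. A minimal standalone sketch of the same pattern with a toy actor (names are illustrative, not Daft's):

    # Toy actor demonstrating the locked get-or-create pattern above.
    # get_if_exists=True returns the existing actor when the (name, namespace)
    # pair is already registered; lifetime="detached" keeps the actor alive
    # after the creating driver exits.
    import threading

    import ray

    _lock = threading.RLock()

    @ray.remote(num_cpus=0)
    class _Registry:
        def __init__(self) -> None:
            self.items: dict[str, str] = {}

        def put(self, key: str, value: str) -> None:
            self.items[key] = value

    def get_registry() -> "ray.actor.ActorHandle":
        # Serialize get-or-create: unsynchronized concurrent calls can race inside Ray.
        with _lock:
            return _Registry.options(
                name="registry", namespace="demo", get_if_exists=True, lifetime="detached"
            ).remote()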
72 changes: 27 additions & 45 deletions daft/runners/ray_tracing.py
@@ -52,8 +52,6 @@
 
 @contextlib.contextmanager
 def ray_tracer(execution_id: str):
-    metrics_actor = ray_metrics.get_metrics_actor(execution_id)
-
     # Dump the RayRunner trace if we detect an active Ray session, otherwise we give up and do not write the trace
     filepath: pathlib.Path | None
     if pathlib.Path(DEFAULT_RAY_LOGS_LOCATION).exists():
@@ -72,26 +70,8 @@ def ray_tracer(execution_id: str):
             runner_tracer = RunnerTracer(f)
             yield runner_tracer
 
-            # Retrieve metrics from the metrics actor and perform some post-processing
-            task_metrics, node_metrics = metrics_actor.collect_metrics()
-            nodes_to_pid_mapping = {node_id: i + NODE_PIDS_START for i, node_id in enumerate(node_metrics)}
-            nodes_workers_to_tid_mapping = {
-                (node_id, worker_id): (pid, tid)
-                for node_id, pid in nodes_to_pid_mapping.items()
-                for tid, worker_id in enumerate(node_metrics[node_id])
-            }
-
-            # Write out collected metrics
-            for metric in task_metrics:
-                runner_tracer.write_task_metric(metric, nodes_workers_to_tid_mapping)
-            runner_tracer.write_stages()
-            runner_tracer._writer.write_footer(
-                [(pid, f"Node {node_id}") for node_id, pid in nodes_to_pid_mapping.items()],
-                [
-                    (pid, tid, f"Worker {worker_id}")
-                    for (_, worker_id), (pid, tid) in nodes_workers_to_tid_mapping.items()
-                ],
-            )
+            metrics_actor = ray_metrics.get_metrics_actor(execution_id)
+            runner_tracer.finalize(metrics_actor)
     else:
         runner_tracer = RunnerTracer(None)
         yield runner_tracer
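
Net effect of this hunk: ray_tracer no longer post-processes metrics inline; it fetches the actor handle and delegates to RunnerTracer.finalize (shown in the next hunk). A hypothetical call site, not from the repo:

    # Hypothetical call site (not from this repo) illustrating the resulting
    # flow: metrics post-processing now runs in RunnerTracer.finalize() when
    # the traced block exits.
    from daft.runners.ray_tracing import ray_tracer

    with ray_tracer("execution-0") as tracer:
        with tracer.dispatch_batching():
            ...  # runner pulls tasks from the physical plan into a batch
        with tracer.dispatching():
            ...  # runner dispatches the batch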
@@ -194,7 +174,26 @@ def __init__(self, file: TextIO | None):
     def _write_event(self, event: dict[str, Any], ts: int | None = None) -> int:
         return self._writer.write_event(event, ts)
 
-    def write_stages(self):
+    def finalize(self, metrics_actor: ray_metrics.MetricsActorHandle) -> None:
+        # Retrieve metrics from the metrics actor and perform some post-processing
+        task_metrics, node_metrics = metrics_actor.collect_metrics()
+        nodes_to_pid_mapping = {node_id: i + NODE_PIDS_START for i, node_id in enumerate(node_metrics)}
+        nodes_workers_to_tid_mapping = {
+            (node_id, worker_id): (pid, tid)
+            for node_id, pid in nodes_to_pid_mapping.items()
+            for tid, worker_id in enumerate(node_metrics[node_id])
+        }
+
+        # Write out collected metrics
+        for metric in task_metrics:
+            self._write_task_metric(metric, nodes_workers_to_tid_mapping)
+        self._write_stages()
+        self._writer.write_footer(
+            [(pid, f"Node {node_id}") for node_id, pid in nodes_to_pid_mapping.items()],
+            [(pid, tid, f"Worker {worker_id}") for (_, worker_id), (pid, tid) in nodes_workers_to_tid_mapping.items()],
+        )
+
+    def _write_stages(self):
         for stage_id in self._stage_start_end:
             start_ts, end_ts = self._stage_start_end[stage_id]
             self._write_event(
@@ -229,7 +228,7 @@ def write_stages(self):
                 ts=end_ts,
             )
 
-    def write_task_metric(
+    def _write_task_metric(
         self, metric: ray_metrics.TaskMetric, nodes_workers_to_pid_tid_mapping: dict[tuple[str, str], tuple[int, int]]
     ):
         # Write to the Async view (will group by the stage ID)
@@ -292,6 +291,10 @@ def write_task_metric(
             ts=int((metric.end - self._start) * 1000 * 1000),
         )
 
+    ###
+    # Tracing of scheduler dispatching
+    ###
+
     @contextlib.contextmanager
     def dispatch_wave(self, wave_num: int):
         self._write_event(
@@ -332,11 +335,6 @@ def count_inflight_tasks(self, count: int):
             }
         )
 
-    ###
-    # Tracing the dispatch batching: when the runner is retrieving enough tasks
-    # from the physical plan in order to put them into a batch.
-    ###
-
     @contextlib.contextmanager
     def dispatch_batching(self):
         self._write_event(
@@ -357,10 +355,6 @@ def dispatch_batching(self):
             }
         )
 
-    ###
-    # Tracing the dispatching of tasks
-    ###
-
     @contextlib.contextmanager
     def dispatching(self):
         self._write_event(
@@ -381,10 +375,6 @@ def dispatching(self):
             }
         )
 
-    ###
-    # Tracing the waiting of tasks
-    ###
-
     @contextlib.contextmanager
     def awaiting(self, waiting_for_num_results: int, wait_timeout_s: float | None):
         name = f"awaiting {waiting_for_num_results} (timeout={wait_timeout_s})"
@@ -410,10 +400,6 @@ def awaiting(self, waiting_for_num_results: int, wait_timeout_s: float | None):
             }
         )
 
-    ###
-    # Tracing the PhysicalPlan
-    ###
-
     @contextlib.contextmanager
     def get_next_physical_plan(self):
         self._write_event(
@@ -442,10 +428,6 @@ def update_args(**kwargs):
             }
         )
 
-    ###
-    # Tracing each individual task as an Async Event
-    ###
-
     def task_created(self, task_id: str, stage_id: int, resource_request: ResourceRequest, instructions: str):
         created_ts = self._write_event(
             {
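A note on the pid/tid bookkeeping in finalize: each Ray node is mapped to a trace "process" and each worker on it to a "thread". A sketch of the Chrome trace event shape the tracer appears to target (field values are illustrative; the repo's exact payloads may differ):

    # Illustrative Chrome trace events, viewable in Perfetto or chrome://tracing.
    # "M" metadata events name pids/tids the way write_footer labels nodes and
    # workers; "X" is a complete event whose ts/dur are in microseconds, matching
    # the int((metric.end - self._start) * 1000 * 1000) math in _write_task_metric.
    import json

    events = [
        {"name": "process_name", "ph": "M", "pid": 9000, "args": {"name": "Node 1f2e3d4c"}},
        {"name": "thread_name", "ph": "M", "pid": 9000, "tid": 0, "args": {"name": "Worker aabbccdd"}},
        {"name": "task-0", "cat": "task", "ph": "X", "pid": 9000, "tid": 0, "ts": 0, "dur": 1500000},
    ]
    print(json.dumps(events, indent=2))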
