diff --git a/daft/runners/ray_metrics.py b/daft/runners/ray_metrics.py index 66b1fbfd6b..650e509040 100644 --- a/daft/runners/ray_metrics.py +++ b/daft/runners/ray_metrics.py @@ -33,8 +33,7 @@ class TaskMetric: class ExecutionMetrics: """Holds the metrics for a given execution ID""" - task_start_info: dict[str, TaskMetric] = dataclasses.field(default_factory=lambda: {}) - task_ends: dict[str, float] = dataclasses.field(default_factory=lambda: {}) + task_metrics: dict[str, TaskMetric] = dataclasses.field(default_factory=lambda: {}) @ray.remote(num_cpus=0) @@ -53,7 +52,7 @@ def mark_task_start( self.execution_node_and_worker_ids[execution_id][node_id_trunc].add(worker_id_trunc) # Update task info - self.execution_metrics[execution_id].task_start_info[task_id] = TaskMetric( + self.execution_metrics[execution_id].task_metrics[task_id] = TaskMetric( task_id=task_id, stage_id=stage_id, start=start, @@ -63,24 +62,22 @@ def mark_task_start( ) def mark_task_end(self, execution_id: str, task_id: str, end: float): - self.execution_metrics[execution_id].task_ends[task_id] = end + self.execution_metrics[execution_id].task_metrics[task_id] = dataclasses.replace( + self.execution_metrics[execution_id].task_metrics[task_id], + end=end, + ) def collect_metrics(self, execution_id: str) -> tuple[list[TaskMetric], dict[str, set[str]]]: """Collect the metrics associated with this execution, cleaning up the memory used for this execution ID""" execution_metrics = self.execution_metrics[execution_id] - data = [ - dataclasses.replace( - execution_metrics.task_start_info[task_id], end=execution_metrics.task_ends.get(task_id) - ) - for task_id in execution_metrics.task_start_info - ] + task_metrics = list(execution_metrics.task_metrics.values()) node_data = self.execution_node_and_worker_ids[execution_id] # Clean up the stats for this execution del self.execution_metrics[execution_id] del self.execution_node_and_worker_ids[execution_id] - return data, node_data + return task_metrics, node_data @dataclasses.dataclass(frozen=True)