diff --git a/daft/runners/ray_runner.py b/daft/runners/ray_runner.py
index ddc98521e2..4d816e1276 100644
--- a/daft/runners/ray_runner.py
+++ b/daft/runners/ray_runner.py
@@ -18,8 +18,8 @@
 from daft.context import execution_config_ctx, get_context
 from daft.daft import PyTable as _PyTable
 from daft.dependencies import np
-from daft.runners import tracer
 from daft.runners.progress_bar import ProgressBar
+from daft.runners.tracer import tracer
 from daft.series import Series, item_to_series
 from daft.table import Table
 
@@ -700,7 +700,7 @@ def place_in_queue(item):
             except Full:
                 pass
 
-    with profiler(profile_filename), tracer.RunnerTracer(trace_filename) as runner_tracer:
+    with profiler(profile_filename), tracer(trace_filename) as runner_tracer:
         wave_count = 0
         try:
             next_step = next(tasks)
diff --git a/daft/runners/tracer.py b/daft/runners/tracer.py
index 70c9aa716d..feeaf6f558 100644
--- a/daft/runners/tracer.py
+++ b/daft/runners/tracer.py
@@ -2,44 +2,53 @@
 
 import contextlib
 import json
+import os
 import time
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, TextIO
 
 if TYPE_CHECKING:
     from daft import ResourceRequest
 
 
-class RunnerTracer:
-    def __init__(self, filepath: str):
-        self._filepath = filepath
+@contextlib.contextmanager
+def tracer(filepath: str):
+    if int(os.environ.get("DAFT_RUNNER_TRACING", 0)) == 1:
+        with open(filepath, "w") as f:
+            # Initialize the JSON file
+            f.write("[")
 
-    def __enter__(self) -> RunnerTracer:
-        self._file = open(self._filepath, "w")
-        self._file.write("[")
-        self._start = time.time()
-        return self
+            # Yield the tracer
+            runner_tracer = RunnerTracer(f)
+            yield runner_tracer
+
+            # Add the final touches to the file
+            f.write(
+                json.dumps({"name": "process_name", "ph": "M", "pid": 1, "args": {"name": "RayRunner dispatch loop"}})
+            )
+            f.write(",\n")
+            f.write(json.dumps({"name": "process_name", "ph": "M", "pid": 2, "args": {"name": "Ray Task Execution"}}))
+            f.write("\n]")
+    else:
+        runner_tracer = RunnerTracer(None)
+        yield runner_tracer
 
-    def __exit__(self, exc_type, exc_value, traceback):
-        self._file.write(
-            json.dumps({"name": "process_name", "ph": "M", "pid": 1, "args": {"name": "RayRunner dispatch loop"}})
-        )
-        self._file.write(",\n")
-        self._file.write(
-            json.dumps({"name": "process_name", "ph": "M", "pid": 2, "args": {"name": "Ray Task Execution"}})
-        )
-        self._file.write("\n]")
-        self._file.close()
+
+class RunnerTracer:
+    def __init__(self, file: TextIO | None):
+        self._file = file
+        self._start = time.time()
 
     def _write_event(self, event: dict[str, Any]):
-        self._file.write(
-            json.dumps(
-                {
-                    **event,
-                    "ts": int((time.time() - self._start) * 1000 * 1000),
-                }
+        if self._file is not None:
+            self._file.write(
+                json.dumps(
+                    {
+                        **event,
+                        "ts": int((time.time() - self._start) * 1000 * 1000),
+                    }
+                )
             )
-        )
-        self._file.write(",\n")
+            self._file.write(",\n")
 
     @contextlib.contextmanager
     def dispatch_wave(self, wave_num: int):
@@ -70,16 +79,6 @@ def metrics_updater(**kwargs):
             }
         )
 
-    # def dispatch_wave_metrics(self, metrics: dict[str, int]):
-    #     """Marks a counter event for various runner counters such as num cores, max inflight tasks etc"""
-    #     self._write_event({
-    #         "name": "dispatch_metrics",
-    #         "ph": "C",
-    #         "pid": 1,
-    #         "tid": 1,
-    #         "args": metrics,
-    #     })
-
     def count_inflight_tasks(self, count: int):
         self._write_event(
             {
@@ -116,15 +115,6 @@ def dispatch_batching(self):
             }
         )
-    # def count_dispatch_batch_size(self, dispatch_batch_size: int):
-    #     self._write_event({
-    #         "name": "dispatch_batch_size",
-    #         "ph": "C",
-    #         "pid": 1,
-    #         "tid": 1,
-    #         "args": {"dispatch_batch_size": dispatch_batch_size},
-    #     })
-
     def mark_noop_task_start(self):
         """Marks the start of running a no-op task"""
         self._write_event(
             {
@@ -248,15 +238,6 @@ def awaiting(self):
             }
         )
 
-    # def count_num_ready(self, num_ready: int):
-    #     self._write_event({
-    #         "name": "awaiting",
-    #         "ph": "C",
-    #         "pid": 1,
-    #         "tid": 1,
-    #         "args": {"num_ready": num_ready},
-    #     })
-
     ###
     # Tracing each individual task as an Async Event
     ###