Skip to content

Commit

Permalink
[CHORE] enable refresh on tqdm total updates (#1654)
Browse files Browse the repository at this point in the history
* We currently were only using a background thread when updating the
done tasks. Now we also check if new tasks are enqueued to perform a
refresh.
  • Loading branch information
samster25 authored Nov 21, 2023
1 parent 06c2ccf commit aaf279e
Showing 1 changed file with 12 additions and 1 deletion.
13 changes: 12 additions & 1 deletion daft/runners/progress_bar.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import os
import time
from typing import Any

from tqdm.auto import tqdm
Expand All @@ -13,6 +14,7 @@ def __init__(self, use_ray_tqdm: bool, show_tasks_bar: bool = False, disable: bo
self.use_ray_tqdm = use_ray_tqdm
self.show_tasks_bar = show_tasks_bar
self.tqdm_mod = tqdm
self._maxinterval = 5.0
self.pbars: dict[int, tqdm] = dict()
self.disable = (
disable
Expand All @@ -25,7 +27,12 @@ def _make_new_bar(self, stage_id: int, name: str):
self.pbars[stage_id] = self.tqdm_mod(total=1, desc=name, position=len(self.pbars))
else:
self.pbars[stage_id] = self.tqdm_mod(
total=1, desc=name, position=len(self.pbars), leave=False, mininterval=1.0
total=1,
desc=name,
position=len(self.pbars),
leave=False,
mininterval=1.0,
maxinterval=self._maxinterval,
)

def mark_task_start(self, step: PartitionTask[Any]) -> None:
Expand All @@ -46,6 +53,10 @@ def mark_task_start(self, step: PartitionTask[Any]) -> None:
else:
pb = self.pbars[stage_id]
pb.total += 1
if hasattr(pb, "last_print_t"):
dt = time.time() - pb.last_print_t
if dt >= self._maxinterval:
pb.refresh()

def mark_task_done(self, step: PartitionTask[Any]) -> None:
if self.disable:
Expand Down

0 comments on commit aaf279e

Please sign in to comment.