diff --git a/src/uberjob/_execution/run_function_on_graph.py b/src/uberjob/_execution/run_function_on_graph.py
index 85b0a0b..8051773 100644
--- a/src/uberjob/_execution/run_function_on_graph.py
+++ b/src/uberjob/_execution/run_function_on_graph.py
@@ -17,18 +17,19 @@
 import os
 import threading
 from contextlib import contextmanager
+from typing import Dict, List, NamedTuple, Set
 
 from uberjob._execution.scheduler import create_queue
-from uberjob._util import Slot
-from uberjob._util.networkx_util import assert_acyclic, predecessor_count, source_nodes
+from uberjob._util.networkx_util import assert_acyclic, predecessor_count
+from uberjob.graph import Node
 
 
 class NodeError(Exception):
-    """An exception was raised during _execution of a node."""
+    """An exception was raised during execution of a node."""
 
     def __init__(self, node):
         super().__init__(
-            f"An exception was raised during _execution of the following node: {node!r}."
+            f"An exception was raised during execution of the following node: {node!r}."
        )
         self.node = node
 
@@ -81,18 +82,33 @@ def worker_pool(queue, process_item, worker_count):
     try:
         for _ in range(worker_count):
             workers.append(worker_thread(queue, process_item))
-        yield queue.put
+        yield
     finally:
         for worker in workers:
             worker.join()
 
 
-class LockDict:
-    def __init__(self, lock_count):
-        self.locks = [threading.Lock() for _ in range(lock_count)]
+class PreparedNodes(NamedTuple):
+    source_nodes: List[Node]
+    single_parent_nodes: Set[Node]
+    remaining_pred_count_mapping: Dict[Node, int]
 
-    def __getitem__(self, key):
-        return self.locks[hash(key) % len(self.locks)]
+
+def prepare_nodes(graph) -> PreparedNodes:
+    source_nodes = []
+    single_parent_nodes = set()
+    remaining_pred_count_mapping = {}
+    for node in graph:
+        count = predecessor_count(graph, node)
+        if count == 0:
+            source_nodes.append(node)
+        elif count == 1:
+            single_parent_nodes.add(node)
+        else:
+            remaining_pred_count_mapping[node] = count
+    return PreparedNodes(
+        source_nodes, single_parent_nodes, remaining_pred_count_mapping
+    )
 
 
 def run_function_on_graph(
@@ -101,14 +117,18 @@
     assert_acyclic(graph)
     worker_count = coerce_worker_count(worker_count)
     max_errors = coerce_max_errors(max_errors)
-    failure_lock = threading.Lock()
-    node_locks = LockDict(8)
-    remaining_predecessor_counts = {
-        node: Slot(predecessor_count(graph, node)) for node in graph
-    }
+
+    source_nodes, single_parent_nodes, remaining_pred_count_mapping = prepare_nodes(
+        graph
+    )
+    remaining_pred_count_lock = threading.Lock()
+
     stop = False
     first_node_error = None
     error_count = 0
+    failure_lock = threading.Lock()
+
+    queue = create_queue(graph, source_nodes, scheduler)
 
     def process_node(node):
         nonlocal stop
@@ -128,18 +148,16 @@ def process_node(node):
                     stop = True
         else:
             for successor in graph.successors(node):
-                remaining_predecessor_count = remaining_predecessor_counts[successor]
-                with node_locks[successor]:
-                    remaining_predecessor_count.value -= 1
-                    should_submit = remaining_predecessor_count.value == 0
-                if should_submit:
-                    submit(successor)
-
-    queue = create_queue(graph, scheduler)
-    with worker_pool(queue, process_node, worker_count) as submit:
+                if successor in single_parent_nodes:
+                    queue.put(successor)
+                else:
+                    with remaining_pred_count_lock:
+                        remaining_pred_count_mapping[successor] -= 1
+                        if remaining_pred_count_mapping[successor] == 0:
+                            queue.put(successor)
+
+    with worker_pool(queue, process_node, worker_count):
         try:
-            for node in source_nodes(graph):
-                submit(node)
             queue.join()
         finally:
             stop = True
diff --git a/src/uberjob/_execution/scheduler.py b/src/uberjob/_execution/scheduler.py
index 71f29d3..ea0d21e 100644
--- a/src/uberjob/_execution/scheduler.py
+++ b/src/uberjob/_execution/scheduler.py
@@ -14,16 +14,27 @@
 # limitations under the License.
 #
 import random
+from collections import deque
 from functools import total_ordering
-from heapq import heappop, heappush
+from heapq import heapify, heappop, heappush
 from queue import Queue
 
 from uberjob._execution import greedy
 
 
+def create_simple_queue(initial_items):
+    queue = Queue()
+    queue.queue = deque(initial_items)
+    queue.unfinished_tasks = len(queue.queue)
+    return queue
+
+
 class RandomQueue(Queue):
-    def _init(self, maxsize):
-        self.queue = []
+    def __init__(self, initial_items):
+        super().__init__()
+        self.queue = list(initial_items)
+        random.shuffle(self.queue)
+        self.unfinished_tasks = len(self.queue)
 
     def _qsize(self):
         return len(self.queue)
@@ -54,9 +65,11 @@ def __lt__(self, other):
 
 
 class PriorityQueue(Queue):
-    def __init__(self, priority, maxsize=0):
-        super().__init__(maxsize)
-        self.queue = []
+    def __init__(self, initial_items, priority):
+        super().__init__()
+        self.queue = [KeyValuePair(priority(item), item) for item in initial_items]
+        heapify(self.queue)
+        self.unfinished_tasks = len(self.queue)
         self.priority = priority
 
     def _qsize(self):
         return len(self.queue)
@@ -69,15 +82,18 @@ def _get(self):
         return heappop(self.queue).value
 
 
-def create_queue(graph, scheduler):
+def create_queue(graph, initial_items, scheduler):
     scheduler = scheduler or "default"
     if scheduler == "cheap":
-        return Queue()
+        return create_simple_queue(initial_items)
     if scheduler == "random":
-        return RandomQueue()
+        return RandomQueue(initial_items)
     if scheduler == "default":
         priority_mapping = greedy.get_priority_mapping(graph)
         # The priority mapping has priorities [0, n).
         # Setting the default priority to -1 gives the DONE sentinel highest priority.
-        return PriorityQueue(lambda node: priority_mapping.get(node, -1))
+        return PriorityQueue(initial_items, lambda node: priority_mapping.get(node, -1))
     raise ValueError(f"Invalid scheduler {scheduler!r}")
+
+
+__all__ = ["create_queue"]
diff --git a/src/uberjob/_registry.py b/src/uberjob/_registry.py
index 389f383..cbd4909 100644
--- a/src/uberjob/_registry.py
+++ b/src/uberjob/_registry.py
@@ -65,8 +65,7 @@ def source(self, plan: Plan, value_store: ValueStore) -> Node:
         validation.assert_is_instance(plan, "plan", Plan)
         validation.assert_is_instance(value_store, "value_store", ValueStore)
         stack_frame = get_stack_frame()
-        node = plan.call(source)
-        node.stack_frame = stack_frame
+        node = plan._call(stack_frame, source)
         self.mapping[node] = RegistryValue(
             value_store, is_source=True, stack_frame=stack_frame
         )
diff --git a/src/uberjob/_run.py b/src/uberjob/_run.py
index 3e652dd..2beb6db 100644
--- a/src/uberjob/_run.py
+++ b/src/uberjob/_run.py
@@ -182,6 +182,7 @@
     progress: Union[None, bool, Progress, Iterable[Progress]] = True,
     scheduler: Optional[str] = None,
     transform_physical: Optional[Callable[[Plan, Node], Tuple[Plan, Node]]] = None,
+    stale_check_max_workers: Optional[int] = None,
 ):
     """
     Run a :class:`~uberjob.Plan`.
@@ -211,6 +212,8 @@
     :param transform_physical: Optional transformation to be applied to the physical plan.
         It takes ``(plan, output_node)`` as positional arguments and
         returns ``(transformed_plan, redirected_output_node)``.
+    :param stale_check_max_workers: Optionally specify the maximum number of threads used for the stale check.
+        The default behavior is to use ``max_workers``.
     :return: The non-symbolic output corresponding to the symbolic output argument.
     """
     assert_is_instance(plan, "plan", Plan)
@@ -224,9 +227,16 @@
     assert_is_instance(max_workers, "max_workers", int, optional=True)
     if max_workers is not None and max_workers < 1:
         raise ValueError("max_workers must be at least 1.")
+    assert_is_instance(
+        stale_check_max_workers, "stale_check_max_workers", int, optional=True
+    )
+    if stale_check_max_workers is not None and stale_check_max_workers < 1:
+        raise ValueError("stale_check_max_workers must be at least 1.")
     assert_is_instance(max_errors, "max_errors", int, optional=True)
     if max_errors is not None and max_errors < 0:
         raise ValueError("max_errors must be nonnegative.")
+    if stale_check_max_workers is None:
+        stale_check_max_workers = max_workers
 
     plan = get_mutable_plan(plan, inplace=False)
@@ -246,7 +256,7 @@
         registry,
         output_node=output_node,
         progress_observer=progress_observer,
-        max_workers=max_workers,
+        max_workers=stale_check_max_workers,
         retry=retry,
         fresh_time=fresh_time,
         inplace=True,
diff --git a/src/uberjob/_util/networkx_util.py b/src/uberjob/_util/networkx_util.py
index 1d55307..860bd0f 100644
--- a/src/uberjob/_util/networkx_util.py
+++ b/src/uberjob/_util/networkx_util.py
@@ -38,12 +38,6 @@ def predecessor_count(graph, node) -> int:
     return len(graph.pred[node])
 
 
-def source_nodes(graph):
-    for node, predecessors in graph.pred.items():
-        if not predecessors:
-            yield node
-
-
 def is_source_node(graph, node) -> bool:
     return not graph.pred[node]
 
diff --git a/src/uberjob/progress/_console_progress_observer.py b/src/uberjob/progress/_console_progress_observer.py
index 1d24c6e..b2040e5 100644
--- a/src/uberjob/progress/_console_progress_observer.py
+++ b/src/uberjob/progress/_console_progress_observer.py
@@ -81,7 +81,7 @@ def _render(self, state, new_exception_index, exception_tuples, elapsed):
         print_ = partial(print, file=buffer)
         _print_header(print_, elapsed)
         for section in ("stale", "run"):
-            scope_mapping = state.section_scope_mapping.get(section)
+            scope_mapping = state.get(section)
             if scope_mapping:
                 is_done = all(
                     s.completed + s.failed == s.total
diff --git a/src/uberjob/progress/_html_progress_observer.py b/src/uberjob/progress/_html_progress_observer.py
index 3f0e208..e8517ec 100644
--- a/src/uberjob/progress/_html_progress_observer.py
+++ b/src/uberjob/progress/_html_progress_observer.py
@@ -130,7 +130,7 @@ def _render(self, state, new_exception_index, exception_tuples, elapsed):
         for section, title in (
             ("stale", "Determining stale value stores"),
             ("run", "Running graph"),
         ):
-            scope_mapping = state.section_scope_mapping.get(section)
+            scope_mapping = state.get(section)
             if scope_mapping:
                 lines.append('<h2>{}</h2>'.format(html.escape(title)))
                 for scope, scope_state in sorted_scope_items(scope_mapping):
diff --git a/src/uberjob/progress/_ipython_progress_observer.py b/src/uberjob/progress/_ipython_progress_observer.py
index 98f1d88..beba698 100644
--- a/src/uberjob/progress/_ipython_progress_observer.py
+++ b/src/uberjob/progress/_ipython_progress_observer.py
@@ -68,7 +68,7 @@ def _render(self, state, new_exception_index, exception_tuples, elapsed):
         for section, title in (
             ("stale", "Determining stale value stores"),
             ("run", "Running graph"),
         ):
-            scope_mapping = state.section_scope_mapping.get(section)
+            scope_mapping = state.get(section)
             if scope_mapping:
                 title_widget = self._get(
                     "section", section, "title", default=widgets.HTML
diff --git a/src/uberjob/progress/_simple_progress_observer.py b/src/uberjob/progress/_simple_progress_observer.py
index 447ef67..af659e9 100644
--- a/src/uberjob/progress/_simple_progress_observer.py
+++ b/src/uberjob/progress/_simple_progress_observer.py
@@ -17,8 +17,6 @@
 import threading
 import time
 from abc import ABC, abstractmethod
-from collections import defaultdict
-from contextlib import contextmanager
 from typing import Tuple
 
 from uberjob.progress._progress_observer import ProgressObserver
@@ -73,14 +71,16 @@ def to_progress_string(self):
 
 
 class State:
-    def __init__(self):
-        self.section_scope_mapping = defaultdict(lambda: defaultdict(ScopeState))
+    def __init__(self, start_time):
+        self.section_scope_mapping = {}
         self.running_count = 0
         self._running_section_scopes = set()
-        self._prev_time = None
+        self._prev_time = start_time
 
     def increment_total(self, section, scope, amount: int):
-        self.section_scope_mapping[section][scope].total += amount
+        self.section_scope_mapping.setdefault(section, {}).setdefault(
+            scope, ScopeState()
+        ).total += amount
 
     def increment_running(self, section, scope):
         self.update_weighted_elapsed()
@@ -110,14 +110,12 @@ def increment_failed(self, section, scope):
 
     def update_weighted_elapsed(self):
         t = time.time()
-        if self._prev_time:
+        if self.running_count:
             elapsed = t - self._prev_time
-            if self.running_count:
-                for section, scope in self._running_section_scopes:
-                    scope_state = self.section_scope_mapping[section][scope]
-                    scope_state.weighted_elapsed += (
-                        elapsed * scope_state.running / self.running_count
-                    )
+            multiplier = elapsed / self.running_count
+            for section, scope in self._running_section_scopes:
+                scope_state = self.section_scope_mapping[section][scope]
+                scope_state.weighted_elapsed += scope_state.running * multiplier
         self._prev_time = t
 
@@ -134,15 +132,14 @@ def __init__(
         self._min_update_interval = min_update_interval
         self._max_update_interval = max_update_interval
         self._max_exception_count = max_exception_count
-        self._state = State()
-        self._running_scope_lookup = defaultdict(set)
         self._exception_tuples = []
         self._new_exception_index = 0
         self._lock = threading.Lock()
         self._stale = True
         self._done_event = threading.Event()
         self._thread = None
-        self._start_time = None
+        self._start_time = time.time()
+        self._state = State(self._start_time)
         self._last_render_time = None
 
     @abstractmethod
@@ -154,7 +151,6 @@ def _output(self, value):
         pass
 
     def __enter__(self):
-        self._start_time = time.time()
         self._thread = threading.Thread(target=self._run_update_thread)
         self._thread.start()
 
@@ -174,7 +170,7 @@ def _do_render(self):
         self._last_render_time = t
         self._state.update_weighted_elapsed()
         output_value = self._render(
-            self._state,
+            self._state.section_scope_mapping,
             self._new_exception_index,
             self._exception_tuples,
             t - self._start_time,
@@ -196,26 +192,24 @@ def _run_update_thread(self):
             if output_value is not None:
                 self._output(output_value)
 
-    @contextmanager
-    def _lock_and_make_stale(self):
+    def increment_total(self, *, section: str, scope: Tuple, amount: int):
         with self._lock:
             self._stale = True
-            yield
-
-    def increment_total(self, *, section: str, scope: Tuple, amount: int):
-        with self._lock_and_make_stale():
             self._state.increment_total(section, scope, amount)
 
     def increment_running(self, *, section: str, scope: Tuple):
-        with self._lock_and_make_stale():
+        with self._lock:
+            self._stale = True
             self._state.increment_running(section, scope)
 
     def increment_completed(self, *, section: str, scope: Tuple):
-        with self._lock_and_make_stale():
+        with self._lock:
+            self._stale = True
             self._state.increment_completed(section, scope)
 
     def increment_failed(self, *, section: str, scope: Tuple, exception: Exception):
-        with self._lock_and_make_stale():
+        with self._lock:
+            self._stale = True
             self._state.increment_failed(section, scope)
             if len(self._exception_tuples) < self._max_exception_count:
                 self._exception_tuples.append(
diff --git a/tests/test_plan.py b/tests/test_plan.py
index d33cec0..f3dd9de 100644
--- a/tests/test_plan.py
+++ b/tests/test_plan.py
@@ -222,18 +222,19 @@ def fizz():
     def test_max_workers_and_max_errors_validation(self):
         plan = uberjob.Plan()
-        with self.assertRaises(TypeError):
-            uberjob.run(plan, max_workers="hello")
-        with self.assertRaises(TypeError):
-            uberjob.run(plan, max_workers=1.0)
-        with self.assertRaises(TypeError):
-            uberjob.run(plan, max_errors="hello")
-        with self.assertRaises(TypeError):
-            uberjob.run(plan, max_errors=1.0)
-        with self.assertRaises(ValueError):
-            uberjob.run(plan, max_workers=-1)
-        with self.assertRaises(ValueError):
-            uberjob.run(plan, max_workers=0)
+
+        for arg_name in ["max_workers", "max_errors", "stale_check_max_workers"]:
+            for arg_value in ["hello", 1.0]:
+                with self.subTest(arg_name=arg_name, arg_value=arg_value):
+                    with self.assertRaises(TypeError):
+                        uberjob.run(plan, **{arg_name: arg_value})
+
+        for arg_name in ["max_workers", "stale_check_max_workers"]:
+            for arg_value in [-1, 0]:
+                with self.subTest(arg_name=arg_name, arg_value=arg_value):
+                    with self.assertRaises(ValueError):
+                        uberjob.run(plan, **{arg_name: arg_value})
+
         with self.assertRaises(ValueError):
             uberjob.run(plan, max_errors=-1)
diff --git a/tests/test_registry.py b/tests/test_registry.py
index 71f4743..64a5a95 100644
--- a/tests/test_registry.py
+++ b/tests/test_registry.py
@@ -61,6 +61,7 @@ def test_registry_simple(self):
         uberjob.run(p, registry=r)
         self.assertEqual(r[x].read_count, 0)
         self.assertEqual(r[x].write_count, 1)
+        uberjob.run(p, registry=r, stale_check_max_workers=1)
 
     def test_pruning(self):
         p = uberjob.Plan()
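Usage sketch (not part of the patch): a minimal example of the new stale_check_max_workers argument introduced above, assuming the existing public uberjob.Plan / plan.call / uberjob.run(plan, output=...) API.

    import uberjob

    plan = uberjob.Plan()
    x = plan.call(sum, plan.call(range, 100))

    # Run the stale check on a single thread while the graph execution itself
    # still uses the default worker count; when the argument is omitted it
    # falls back to max_workers.
    result = uberjob.run(plan, output=x, stale_check_max_workers=1)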