diff --git a/src/uberjob/_execution/run_function_on_graph.py b/src/uberjob/_execution/run_function_on_graph.py
index 85b0a0b..8051773 100644
--- a/src/uberjob/_execution/run_function_on_graph.py
+++ b/src/uberjob/_execution/run_function_on_graph.py
@@ -17,18 +17,19 @@
 import os
 import threading
 from contextlib import contextmanager
+from typing import Dict, List, NamedTuple, Set
 
 from uberjob._execution.scheduler import create_queue
-from uberjob._util import Slot
-from uberjob._util.networkx_util import assert_acyclic, predecessor_count, source_nodes
+from uberjob._util.networkx_util import assert_acyclic, predecessor_count
+from uberjob.graph import Node
 
 
 class NodeError(Exception):
-    """An exception was raised during _execution of a node."""
+    """An exception was raised during execution of a node."""
 
     def __init__(self, node):
         super().__init__(
-            f"An exception was raised during _execution of the following node: {node!r}."
+            f"An exception was raised during execution of the following node: {node!r}."
         )
         self.node = node
 
@@ -81,18 +82,33 @@ def worker_pool(queue, process_item, worker_count):
     try:
         for _ in range(worker_count):
             workers.append(worker_thread(queue, process_item))
-        yield queue.put
+        yield
     finally:
         for worker in workers:
             worker.join()
 
 
-class LockDict:
-    def __init__(self, lock_count):
-        self.locks = [threading.Lock() for _ in range(lock_count)]
+class PreparedNodes(NamedTuple):
+    source_nodes: List[Node]
+    single_parent_nodes: Set[Node]
+    remaining_pred_count_mapping: Dict[Node, int]
 
-    def __getitem__(self, key):
-        return self.locks[hash(key) % len(self.locks)]
+
+def prepare_nodes(graph) -> PreparedNodes:
+    source_nodes = []
+    single_parent_nodes = set()
+    remaining_pred_count_mapping = {}
+    for node in graph:
+        count = predecessor_count(graph, node)
+        if count == 0:
+            source_nodes.append(node)
+        elif count == 1:
+            single_parent_nodes.add(node)
+        else:
+            remaining_pred_count_mapping[node] = count
+    return PreparedNodes(
+        source_nodes, single_parent_nodes, remaining_pred_count_mapping
+    )
 
 
 def run_function_on_graph(
@@ -101,14 +117,18 @@
     assert_acyclic(graph)
     worker_count = coerce_worker_count(worker_count)
     max_errors = coerce_max_errors(max_errors)
-    failure_lock = threading.Lock()
-    node_locks = LockDict(8)
-    remaining_predecessor_counts = {
-        node: Slot(predecessor_count(graph, node)) for node in graph
-    }
+
+    source_nodes, single_parent_nodes, remaining_pred_count_mapping = prepare_nodes(
+        graph
+    )
+    remaining_pred_count_lock = threading.Lock()
+
     stop = False
     first_node_error = None
     error_count = 0
+    failure_lock = threading.Lock()
+
+    queue = create_queue(graph, source_nodes, scheduler)
 
     def process_node(node):
         nonlocal stop
@@ -128,18 +148,16 @@
                     stop = True
         else:
             for successor in graph.successors(node):
-                remaining_predecessor_count = remaining_predecessor_counts[successor]
-                with node_locks[successor]:
-                    remaining_predecessor_count.value -= 1
-                    should_submit = remaining_predecessor_count.value == 0
-                if should_submit:
-                    submit(successor)
-
-    queue = create_queue(graph, scheduler)
-    with worker_pool(queue, process_node, worker_count) as submit:
+                if successor in single_parent_nodes:
+                    queue.put(successor)
+                else:
+                    with remaining_pred_count_lock:
+                        remaining_pred_count_mapping[successor] -= 1
+                        if remaining_pred_count_mapping[successor] == 0:
+                            queue.put(successor)
+
+    with worker_pool(queue, process_node, worker_count):
         try:
-            for node in source_nodes(graph):
-                submit(node)
             queue.join()
         finally:
             stop = True
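The reworked traversal above classifies nodes by predecessor count so that a successor with exactly one parent can be enqueued as soon as that parent finishes, without touching the shared counter lock. A minimal sketch of that idea follows, written against a networkx DiGraph; the helper names (prepare, on_node_done, ready_queue) are illustrative and are not part of uberjob.

# Sketch of the single-parent fast path, assuming a networkx-style DiGraph.
# The names prepare, on_node_done, and ready_queue are illustrative, not uberjob API.
import threading
from queue import Queue

import networkx as nx


def prepare(graph):
    sources, single_parent, remaining = [], set(), {}
    for node in graph:
        count = len(graph.pred[node])
        if count == 0:
            sources.append(node)
        elif count == 1:
            single_parent.add(node)  # exactly one parent: no shared counter needed
        else:
            remaining[node] = count
    return sources, single_parent, remaining


def on_node_done(graph, node, single_parent, remaining, lock, ready_queue):
    # Called when `node` finishes; enqueue successors whose dependencies are now met.
    for successor in graph.successors(node):
        if successor in single_parent:
            ready_queue.put(successor)  # its only parent just completed
        else:
            with lock:  # multi-parent successors still share a counter and a lock
                remaining[successor] -= 1
                if remaining[successor] == 0:
                    ready_queue.put(successor)


graph = nx.DiGraph([("a", "c"), ("b", "c"), ("a", "d")])
sources, single_parent, remaining = prepare(graph)
ready_queue = Queue()
for node in sources:
    ready_queue.put(node)  # "a" and "b" have no predecessors
on_node_done(graph, "a", single_parent, remaining, threading.Lock(), ready_queue)
# Only "d" is enqueued: it has a single parent; "c" still waits on "b".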
diff --git a/src/uberjob/_execution/scheduler.py b/src/uberjob/_execution/scheduler.py
index 71f29d3..ea0d21e 100644
--- a/src/uberjob/_execution/scheduler.py
+++ b/src/uberjob/_execution/scheduler.py
@@ -14,16 +14,27 @@
 # limitations under the License.
 #
 import random
+from collections import deque
 from functools import total_ordering
-from heapq import heappop, heappush
+from heapq import heapify, heappop, heappush
 from queue import Queue
 
 from uberjob._execution import greedy
 
 
+def create_simple_queue(initial_items):
+    queue = Queue()
+    queue.queue = deque(initial_items)
+    queue.unfinished_tasks = len(queue.queue)
+    return queue
+
+
 class RandomQueue(Queue):
-    def _init(self, maxsize):
-        self.queue = []
+    def __init__(self, initial_items):
+        super().__init__()
+        self.queue = list(initial_items)
+        random.shuffle(self.queue)
+        self.unfinished_tasks = len(self.queue)
 
     def _qsize(self):
         return len(self.queue)
@@ -54,9 +65,11 @@ def __lt__(self, other):
 
 
 class PriorityQueue(Queue):
-    def __init__(self, priority, maxsize=0):
-        super().__init__(maxsize)
-        self.queue = []
+    def __init__(self, initial_items, priority):
+        super().__init__()
+        self.queue = [KeyValuePair(priority(item), item) for item in initial_items]
+        heapify(self.queue)
+        self.unfinished_tasks = len(self.queue)
         self.priority = priority
 
     def _qsize(self):
@@ -69,15 +82,18 @@ def _get(self):
         return heappop(self.queue).value
 
 
-def create_queue(graph, scheduler):
+def create_queue(graph, initial_items, scheduler):
     scheduler = scheduler or "default"
     if scheduler == "cheap":
-        return Queue()
+        return create_simple_queue(initial_items)
     if scheduler == "random":
-        return RandomQueue()
+        return RandomQueue(initial_items)
    if scheduler == "default":
         priority_mapping = greedy.get_priority_mapping(graph)
         # The priority mapping has priorities [0, n).
         # Setting the default priority to -1 gives the DONE sentinel highest priority.
-        return PriorityQueue(lambda node: priority_mapping.get(node, -1))
+        return PriorityQueue(initial_items, lambda node: priority_mapping.get(node, -1))
     raise ValueError(f"Invalid scheduler {scheduler!r}")
+
+
+__all__ = ["create_queue"]
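The pre-seeded queues above reach into queue.Queue internals: placing items directly into queue.queue bypasses put(), so unfinished_tasks has to be set by hand or task_done() and join() would fall out of sync. A small standalone demonstration of that accounting follows (create_simple_queue mirrors the patch; the usage lines are illustrative only).

# Standalone demonstration of the queue.Queue accounting the patch relies on.
# create_simple_queue mirrors the patch; the usage lines are illustrative only.
from collections import deque
from queue import Queue


def create_simple_queue(initial_items):
    queue = Queue()
    queue.queue = deque(initial_items)  # items bypass put(), so...
    queue.unfinished_tasks = len(queue.queue)  # ...join()/task_done() must be told about them
    return queue


q = create_simple_queue(["a", "b"])
print(q.get())  # "a"
q.task_done()   # without unfinished_tasks set above, this would raise ValueError
print(q.get())  # "b"
q.task_done()
q.join()        # returns immediately because every pre-seeded item is accounted for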
diff --git a/src/uberjob/_registry.py b/src/uberjob/_registry.py
index 389f383..cbd4909 100644
--- a/src/uberjob/_registry.py
+++ b/src/uberjob/_registry.py
@@ -65,8 +65,7 @@ def source(self, plan: Plan, value_store: ValueStore) -> Node:
         validation.assert_is_instance(plan, "plan", Plan)
         validation.assert_is_instance(value_store, "value_store", ValueStore)
         stack_frame = get_stack_frame()
-        node = plan.call(source)
-        node.stack_frame = stack_frame
+        node = plan._call(stack_frame, source)
         self.mapping[node] = RegistryValue(
             value_store, is_source=True, stack_frame=stack_frame
         )
diff --git a/src/uberjob/_run.py b/src/uberjob/_run.py
index 3e652dd..2beb6db 100644
--- a/src/uberjob/_run.py
+++ b/src/uberjob/_run.py
@@ -182,6 +182,7 @@ def run(
     progress: Union[None, bool, Progress, Iterable[Progress]] = True,
     scheduler: Optional[str] = None,
     transform_physical: Optional[Callable[[Plan, Node], Tuple[Plan, Node]]] = None,
+    stale_check_max_workers: Optional[int] = None,
 ):
     """
     Run a :class:`~uberjob.Plan`.
@@ -211,6 +212,8 @@
     :param transform_physical: Optional transformation to be applied to the physical plan.
         It takes ``(plan, output_node)`` as positional arguments and returns
         ``(transformed_plan, redirected_output_node)``.
+    :param stale_check_max_workers: Optionally specify the maximum number of threads used for the stale check.
+        The default behavior is to use ``max_workers``.
     :return: The non-symbolic output corresponding to the symbolic output argument.
     """
     assert_is_instance(plan, "plan", Plan)
@@ -224,9 +227,16 @@
     assert_is_instance(max_workers, "max_workers", int, optional=True)
     if max_workers is not None and max_workers < 1:
         raise ValueError("max_workers must be at least 1.")
+    assert_is_instance(
+        stale_check_max_workers, "stale_check_max_workers", int, optional=True
+    )
+    if stale_check_max_workers is not None and stale_check_max_workers < 1:
+        raise ValueError("stale_check_max_workers must be at least 1.")
     assert_is_instance(max_errors, "max_errors", int, optional=True)
     if max_errors is not None and max_errors < 0:
         raise ValueError("max_errors must be nonnegative.")
+    if stale_check_max_workers is None:
+        stale_check_max_workers = max_workers
 
     plan = get_mutable_plan(plan, inplace=False)
 
@@ -246,7 +256,7 @@
             registry,
             output_node=output_node,
             progress_observer=progress_observer,
-            max_workers=max_workers,
+            max_workers=stale_check_max_workers,
             retry=retry,
             fresh_time=fresh_time,
             inplace=True,
diff --git a/src/uberjob/_util/networkx_util.py b/src/uberjob/_util/networkx_util.py
index 1d55307..860bd0f 100644
--- a/src/uberjob/_util/networkx_util.py
+++ b/src/uberjob/_util/networkx_util.py
@@ -38,12 +38,6 @@ def predecessor_count(graph, node) -> int:
     return len(graph.pred[node])
 
 
-def source_nodes(graph):
-    for node, predecessors in graph.pred.items():
-        if not predecessors:
-            yield node
-
-
 def is_source_node(graph, node) -> bool:
     return not graph.pred[node]
 
diff --git a/src/uberjob/progress/_console_progress_observer.py b/src/uberjob/progress/_console_progress_observer.py
index 1d24c6e..b2040e5 100644
--- a/src/uberjob/progress/_console_progress_observer.py
+++ b/src/uberjob/progress/_console_progress_observer.py
@@ -81,7 +81,7 @@ def _render(self, state, new_exception_index, exception_tuples, elapsed):
         print_ = partial(print, file=buffer)
         _print_header(print_, elapsed)
         for section in ("stale", "run"):
-            scope_mapping = state.section_scope_mapping.get(section)
+            scope_mapping = state.get(section)
             if scope_mapping:
                 is_done = all(
                     s.completed + s.failed == s.total
diff --git a/src/uberjob/progress/_html_progress_observer.py b/src/uberjob/progress/_html_progress_observer.py
index 3f0e208..e8517ec 100644
--- a/src/uberjob/progress/_html_progress_observer.py
+++ b/src/uberjob/progress/_html_progress_observer.py
@@ -130,7 +130,7 @@ def _render(self, state, new_exception_index, exception_tuples, elapsed):
             ("stale", "Determining stale value stores"),
             ("run", "Running graph"),
         ):
-            scope_mapping = state.section_scope_mapping.get(section)
+            scope_mapping = state.get(section)
             if scope_mapping:
                 lines.append('