diff --git a/src/uberjob/_errors.py b/src/uberjob/_errors.py index 08b42e5..56bf78a 100644 --- a/src/uberjob/_errors.py +++ b/src/uberjob/_errors.py @@ -39,3 +39,12 @@ def __init__(self, call: Call): class NotTransformedError(Exception): """An expected transformation was not applied.""" + + +def create_chained_call_error(call: Call, exception: Exception) -> CallError: + call_error = CallError(call) + call_error.__cause__ = exception + return call_error + + +__all__ = ["CallError", "NotTransformedError", "create_chained_call_error"] diff --git a/src/uberjob/_execution/run_physical.py b/src/uberjob/_execution/run_physical.py index 1bcc98b..22b6c9b 100644 --- a/src/uberjob/_execution/run_physical.py +++ b/src/uberjob/_execution/run_physical.py @@ -16,12 +16,16 @@ """Functionality for executing a physical plan""" from typing import Any, Callable, Dict, NamedTuple, Optional +from uberjob._errors import create_chained_call_error from uberjob._execution.run_function_on_graph import run_function_on_graph +from uberjob._graph import get_full_scope from uberjob._plan import Plan from uberjob._transformations.pruning import prune_source_literals from uberjob._util import Slot from uberjob._util.retry import identity from uberjob.graph import Call, Graph, Literal, Node, get_argument_nodes +from uberjob.progress._null_progress_observer import NullProgressObserver +from uberjob.progress._progress_observer import ProgressObserver class BoundCall: @@ -74,42 +78,38 @@ class PrepRunPhysical(NamedTuple): plan: Plan -def _default_on_failed(node: Node, e: Exception): - pass - - def prep_run_physical( plan: Plan, *, inplace: bool, output_node: Optional[Node] = None, retry: Optional[Callable[[Callable], Callable]] = None, - on_started: Optional[Callable[[Node], None]] = None, - on_completed: Optional[Callable[[Node], None]] = None, - on_failed: Optional[Callable[[Node, Exception], None]] = None, + progress_observer: Optional[ProgressObserver] = None, ): bound_call_lookup, output_slot = _create_bound_call_lookup_and_output_slot( plan, output_node ) plan = prune_source_literals(plan, inplace=inplace) - on_started = on_started or identity - on_completed = on_completed or identity - on_failed = on_failed or _default_on_failed retry = retry or identity + progress_observer = progress_observer or NullProgressObserver() def process(node): if type(node) is Call: - on_started(node) + scope = get_full_scope(plan.graph, node) + progress_observer.increment_running(section="run", scope=scope) bound_call = bound_call_lookup[node] try: retry(bound_call.value.run)() - except Exception as e: - on_failed(node, e) + except Exception as exception: + progress_observer.increment_failed( + section="run", + scope=scope, + exception=create_chained_call_error(node, exception), + ) raise finally: bound_call.value = None - - on_completed(node) + progress_observer.increment_completed(section="run", scope=scope) return PrepRunPhysical(bound_call_lookup, output_slot, process, plan) @@ -123,18 +123,14 @@ def run_physical( max_workers: Optional[int] = None, max_errors: Optional[int] = 0, scheduler: Optional[str] = None, - on_started: Optional[Callable[[Node], None]] = None, - on_completed: Optional[Callable[[Node], None]] = None, - on_failed: Optional[Callable[[Node, Exception], None]] = None, + progress_observer: ProgressObserver, ) -> Any: _, output_slot, process, plan = prep_run_physical( plan, output_node=output_node, retry=retry, inplace=inplace, - on_started=on_started, - on_completed=on_completed, - on_failed=on_failed, + progress_observer=progress_observer, ) run_function_on_graph( diff --git a/src/uberjob/_graph.py b/src/uberjob/_graph.py new file mode 100644 index 0000000..eaa124b --- /dev/null +++ b/src/uberjob/_graph.py @@ -0,0 +1,26 @@ +# +# Copyright 2020 Two Sigma Open Source, LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from typing import Tuple + +from uberjob.graph import Graph, Node + + +def get_full_scope(graph: Graph, node: Node) -> Tuple: + node_data = graph.nodes[node] + return node_data["scope"] + node_data.get("implicit_scope", ()) + + +__all__ = ["get_full_scope"] diff --git a/src/uberjob/_run.py b/src/uberjob/_run.py index 2beb6db..617a14a 100644 --- a/src/uberjob/_run.py +++ b/src/uberjob/_run.py @@ -20,127 +20,32 @@ from uberjob._errors import CallError from uberjob._execution.run_function_on_graph import NodeError from uberjob._execution.run_physical import run_physical +from uberjob._graph import get_full_scope from uberjob._plan import Plan from uberjob._registry import Registry from uberjob._transformations import get_mutable_plan from uberjob._transformations.caching import plan_with_value_stores from uberjob._transformations.pruning import prune_plan -from uberjob._util import fully_qualified_name from uberjob._util.retry import create_retry from uberjob._util.validation import assert_is_callable, assert_is_instance -from uberjob.graph import Call, Graph, Node +from uberjob.graph import Call, Node from uberjob.progress import ( Progress, + ProgressObserver, composite_progress, default_progress, null_progress, ) -def _get_full_scope(graph: Graph, node: Node): - node_data = graph.nodes[node] - return node_data["scope"] + node_data.get("implicit_scope", ()) - - -def _get_scope_call_counts(plan: Plan): - return collections.Counter( - _get_full_scope(plan.graph, node) +def _update_run_totals(plan: Plan, progress_observer: ProgressObserver) -> None: + scope_counts = collections.Counter( + get_full_scope(plan.graph, node) for node in plan.graph.nodes() if type(node) is Call ) - - -def _prepare_plan_with_registry_and_progress( - plan, - registry, - *, - output_node, - progress_observer, - max_workers, - retry, - fresh_time, - inplace, -): - graph = plan.graph - - def get_stale_scope(node): - scope = _get_full_scope(graph, node) - value_store = registry.get(node) - if not value_store: - return scope - return (*scope, fully_qualified_name(value_store.__class__)) - - def on_started_stale_check(node): - progress_observer.increment_running( - section="stale", scope=get_stale_scope(node) - ) - - def on_completed_stale_check(node): - progress_observer.increment_completed( - section="stale", scope=get_stale_scope(node) - ) - - scope_counts = collections.Counter( - get_stale_scope(node) for node in plan.graph.nodes() if type(node) is Call - ) for scope, count in scope_counts.items(): - progress_observer.increment_total(section="stale", scope=scope, amount=count) - - return plan_with_value_stores( - plan, - registry, - output_node=output_node, - max_workers=max_workers, - retry=retry, - fresh_time=fresh_time, - inplace=inplace, - on_started=on_started_stale_check, - on_completed=on_completed_stale_check, - ) - - -def _run_physical_with_progress( - plan, - *, - output_node, - progress_observer, - max_workers, - max_errors, - retry, - scheduler, - inplace, -): - graph = plan.graph - - def on_started_run(node): - progress_observer.increment_running( - section="run", scope=_get_full_scope(graph, node) - ) - - def on_completed_run(node): - progress_observer.increment_completed( - section="run", scope=_get_full_scope(graph, node) - ) - - def on_failed_run(node, exception): - call_error = CallError(node) - call_error.__cause__ = exception - progress_observer.increment_failed( - section="run", scope=_get_full_scope(graph, node), exception=call_error - ) - - return run_physical( - plan, - output_node=output_node, - max_workers=max_workers, - max_errors=max_errors, - retry=retry, - scheduler=scheduler, - inplace=inplace, - on_started=on_started_run, - on_completed=on_completed_run, - on_failed=on_failed_run, - ) + progress_observer.increment_total(section="run", scope=scope, amount=count) def _coerce_progress( @@ -251,7 +156,7 @@ def run( try: with progress_observer: if registry: - plan, redirected_output_node = _prepare_plan_with_registry_and_progress( + plan, redirected_output_node = plan_with_value_stores( plan, registry, output_node=output_node, @@ -271,15 +176,12 @@ def run( plan, redirected_output_node ) - for scope, count in _get_scope_call_counts(plan).items(): - progress_observer.increment_total( - section="run", scope=scope, amount=count - ) + _update_run_totals(plan, progress_observer) if dry_run: return plan, redirected_output_node - return _run_physical_with_progress( + return run_physical( plan, output_node=redirected_output_node, progress_observer=progress_observer, diff --git a/src/uberjob/_transformations/caching.py b/src/uberjob/_transformations/caching.py index 8846ba9..17db9cc 100644 --- a/src/uberjob/_transformations/caching.py +++ b/src/uberjob/_transformations/caching.py @@ -13,16 +13,20 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import collections import datetime as dt -from typing import Callable, Optional, Set, Tuple +from typing import Optional, Set, Tuple +from uberjob._errors import create_chained_call_error from uberjob._execution.run_function_on_graph import run_function_on_graph +from uberjob._graph import get_full_scope from uberjob._plan import Plan from uberjob._registry import Registry, RegistryValue from uberjob._transformations import get_mutable_plan from uberjob._transformations.pruning import prune_plan, prune_source_literals -from uberjob._util import Slot, safe_max +from uberjob._util import Slot, fully_qualified_name, safe_max from uberjob.graph import Call, Dependency, KeywordArg, Node, PositionalArg, get_scope +from uberjob.progress._progress_observer import ProgressObserver class BarrierType: @@ -39,6 +43,14 @@ def _to_naive_local_timezone(value: Optional[dt.datetime]) -> Optional[dt.dateti return value.astimezone().replace(tzinfo=None) if value and value.tzinfo else value +def _get_stale_scope(node: Node, plan: Plan, registry: Registry) -> Tuple: + scope = get_full_scope(plan.graph, node) + value_store = registry.get(node) + if value_store is None: + return scope + return (*scope, fully_qualified_name(value_store.__class__)) + + def _get_stale_nodes( plan: Plan, registry: Registry, @@ -46,8 +58,7 @@ def _get_stale_nodes( retry, max_workers: Optional[int] = None, fresh_time: Optional[dt.datetime] = None, - on_started: Optional[Callable[[Node], None]] = None, - on_completed: Optional[Callable[[Node], None]] = None + progress_observer: ProgressObserver, ) -> Set[Node]: plan = prune_source_literals( plan, inplace=False, predicate=lambda node: node not in registry @@ -89,13 +100,21 @@ def process(node): process_no_stale_ancestor(node) def process_with_callbacks(node): - if on_started is not None and type(node) is Call: - on_started(node) - try: + if type(node) is Call: + scope = _get_stale_scope(node, plan, registry) + progress_observer.increment_running(section="stale", scope=scope) + try: + process(node) + except Exception as exception: + progress_observer.increment_failed( + section="stale", + scope=scope, + exception=create_chained_call_error(node, exception), + ) + raise + progress_observer.increment_completed(section="stale", scope=scope) + else: process(node) - finally: - if on_completed is not None and type(node) is Call: - on_completed(node) run_function_on_graph( plan.graph, process_with_callbacks, worker_count=max_workers, scheduler="cheap" @@ -145,6 +164,18 @@ def nested_call(*args): return write_node, read_node +def _update_stale_totals( + plan: Plan, registry: Registry, progress_observer: ProgressObserver +) -> None: + scope_counts = collections.Counter( + _get_stale_scope(node, plan, registry) + for node in plan.graph.nodes() + if type(node) is Call + ) + for scope, count in scope_counts.items(): + progress_observer.increment_total(section="stale", scope=scope, amount=count) + + def plan_with_value_stores( plan: Plan, registry: Registry, @@ -154,9 +185,9 @@ def plan_with_value_stores( retry, fresh_time: Optional[dt.datetime] = None, inplace: bool, - on_started: Optional[Callable[[Node], None]] = None, - on_completed: Optional[Callable[[Node], None]] = None + progress_observer, ) -> Tuple[Plan, Optional[Node]]: + _update_stale_totals(plan, registry, progress_observer) plan = get_mutable_plan(plan, inplace=inplace) stale_nodes = _get_stale_nodes( plan, @@ -164,8 +195,7 @@ def plan_with_value_stores( max_workers=max_workers, retry=retry, fresh_time=fresh_time, - on_started=on_started, - on_completed=on_completed, + progress_observer=progress_observer, ) read_node_lookup = {} required_nodes = set()