Merge pull request #6 from twosigma/performance-2021-03
implemented various performance improvements
timothy-shields authored Mar 8, 2021
2 parents cf424cd + 4a1ec8e commit c7318ea
Showing 11 changed files with 120 additions and 87 deletions.
70 changes: 44 additions & 26 deletions src/uberjob/_execution/run_function_on_graph.py
@@ -17,18 +17,19 @@
 import os
 import threading
 from contextlib import contextmanager
+from typing import Dict, List, NamedTuple, Set

 from uberjob._execution.scheduler import create_queue
 from uberjob._util import Slot
-from uberjob._util.networkx_util import assert_acyclic, predecessor_count, source_nodes
+from uberjob._util.networkx_util import assert_acyclic, predecessor_count
 from uberjob.graph import Node


 class NodeError(Exception):
-    """An exception was raised during _execution of a node."""
+    """An exception was raised during execution of a node."""

     def __init__(self, node):
         super().__init__(
-            f"An exception was raised during _execution of the following node: {node!r}."
+            f"An exception was raised during execution of the following node: {node!r}."
         )
         self.node = node

@@ -81,18 +82,33 @@ def worker_pool(queue, process_item, worker_count):
     try:
         for _ in range(worker_count):
             workers.append(worker_thread(queue, process_item))
-        yield queue.put
+        yield
     finally:
         for worker in workers:
             worker.join()


-class LockDict:
-    def __init__(self, lock_count):
-        self.locks = [threading.Lock() for _ in range(lock_count)]
-
-    def __getitem__(self, key):
-        return self.locks[hash(key) % len(self.locks)]
+class PreparedNodes(NamedTuple):
+    source_nodes: List[Node]
+    single_parent_nodes: Set[Node]
+    remaining_pred_count_mapping: Dict[Node, int]
+
+
+def prepare_nodes(graph) -> PreparedNodes:
+    source_nodes = []
+    single_parent_nodes = set()
+    remaining_pred_count_mapping = {}
+    for node in graph:
+        count = predecessor_count(graph, node)
+        if count == 0:
+            source_nodes.append(node)
+        elif count == 1:
+            single_parent_nodes.add(node)
+        else:
+            remaining_pred_count_mapping[node] = count
+    return PreparedNodes(
+        source_nodes, single_parent_nodes, remaining_pred_count_mapping
+    )


 def run_function_on_graph(
@@ -101,14 +117,18 @@
     assert_acyclic(graph)
     worker_count = coerce_worker_count(worker_count)
     max_errors = coerce_max_errors(max_errors)
-    failure_lock = threading.Lock()
-    node_locks = LockDict(8)
-    remaining_predecessor_counts = {
-        node: Slot(predecessor_count(graph, node)) for node in graph
-    }
+    source_nodes, single_parent_nodes, remaining_pred_count_mapping = prepare_nodes(
+        graph
+    )
+    remaining_pred_count_lock = threading.Lock()

     stop = False
     first_node_error = None
     error_count = 0
+    failure_lock = threading.Lock()
+
+    queue = create_queue(graph, source_nodes, scheduler)

     def process_node(node):
         nonlocal stop
@@ -128,18 +148,16 @@ def process_node(node):
             stop = True
         else:
             for successor in graph.successors(node):
-                remaining_predecessor_count = remaining_predecessor_counts[successor]
-                with node_locks[successor]:
-                    remaining_predecessor_count.value -= 1
-                    should_submit = remaining_predecessor_count.value == 0
-                if should_submit:
-                    submit(successor)
-
-    queue = create_queue(graph, scheduler)
-    with worker_pool(queue, process_node, worker_count) as submit:
+                if successor in single_parent_nodes:
+                    queue.put(successor)
+                else:
+                    with remaining_pred_count_lock:
+                        remaining_pred_count_mapping[successor] -= 1
+                        if remaining_pred_count_mapping[successor] == 0:
+                            queue.put(successor)
+
+    with worker_pool(queue, process_node, worker_count):
         try:
-            for node in source_nodes(graph):
-                submit(node)
             queue.join()
         finally:
            stop = True
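Note on the scheduling change above: the sharded-lock scheme (LockDict plus per-node Slot counters) is replaced by a single pass that partitions nodes by predecessor count. A node with exactly one predecessor can only become ready on the thread that just finished that predecessor, so it is enqueued directly with no shared counter; only nodes with two or more predecessors take remaining_pred_count_lock. A minimal single-threaded sketch of the idea (illustration only, not the library's API; networkx assumed):

    import networkx as nx
    from queue import Queue

    def run_sketch(graph: nx.DiGraph, process):
        # One pass over the graph, mirroring prepare_nodes above.
        sources = [n for n in graph if graph.in_degree(n) == 0]
        single_parent = {n for n in graph if graph.in_degree(n) == 1}
        remaining = {n: d for n, d in graph.in_degree() if d > 1}

        queue = Queue()
        for n in sources:
            queue.put(n)
        for _ in range(len(graph)):
            node = queue.get()
            process(node)
            for successor in graph.successors(node):
                if successor in single_parent:
                    # Sole parent just finished: no counter, no lock needed.
                    queue.put(successor)
                else:
                    # The multithreaded version guards this decrement with a lock.
                    remaining[successor] -= 1
                    if remaining[successor] == 0:
                        queue.put(successor)

    g = nx.DiGraph([("a", "b"), ("a", "c"), ("b", "d"), ("c", "d")])
    run_sketch(g, print)  # prints a valid topological order, e.g. a, b, c, d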
36 changes: 26 additions & 10 deletions src/uberjob/_execution/scheduler.py
@@ -14,16 +14,27 @@
 # limitations under the License.
 #
 import random
+from collections import deque
 from functools import total_ordering
-from heapq import heappop, heappush
+from heapq import heapify, heappop, heappush
 from queue import Queue

 from uberjob._execution import greedy


+def create_simple_queue(initial_items):
+    queue = Queue()
+    queue.queue = deque(initial_items)
+    queue.unfinished_tasks = len(queue.queue)
+    return queue
+
+
 class RandomQueue(Queue):
-    def _init(self, maxsize):
-        self.queue = []
+    def __init__(self, initial_items):
+        super().__init__()
+        self.queue = list(initial_items)
+        random.shuffle(self.queue)
+        self.unfinished_tasks = len(self.queue)

     def _qsize(self):
         return len(self.queue)
@@ -54,9 +65,11 @@ def __lt__(self, other):


 class PriorityQueue(Queue):
-    def __init__(self, priority, maxsize=0):
-        super().__init__(maxsize)
-        self.queue = []
+    def __init__(self, initial_items, priority):
+        super().__init__()
+        self.queue = [KeyValuePair(priority(item), item) for item in initial_items]
+        heapify(self.queue)
+        self.unfinished_tasks = len(self.queue)
         self.priority = priority

     def _qsize(self):
@@ -69,15 +82,18 @@ def _get(self):
         return heappop(self.queue).value


-def create_queue(graph, scheduler):
+def create_queue(graph, initial_items, scheduler):
     scheduler = scheduler or "default"
     if scheduler == "cheap":
-        return Queue()
+        return create_simple_queue(initial_items)
     if scheduler == "random":
-        return RandomQueue()
+        return RandomQueue(initial_items)
     if scheduler == "default":
         priority_mapping = greedy.get_priority_mapping(graph)
-        return PriorityQueue(lambda node: priority_mapping.get(node, -1))
+        # The priority mapping has priorities [0, n).
+        # Setting the default priority to -1 gives the DONE sentinel highest priority.
+        return PriorityQueue(initial_items, lambda node: priority_mapping.get(node, -1))
     raise ValueError(f"Invalid scheduler {scheduler!r}")


 __all__ = ["create_queue"]
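Two performance details in this file: each queue is now seeded with the initially ready nodes at construction, and PriorityQueue builds its heap with a single O(n) heapify instead of n heappush calls, which would cost O(n log n). The seeding in create_simple_queue writes queue.queue and unfinished_tasks directly, relying on CPython's queue.Queue internals: put() is what normally appends and increments unfinished_tasks under a lock, and join() blocks until task_done() drives that counter back to zero. A small sketch of that accounting (assumes the stdlib internals used in the diff):

    from collections import deque
    from queue import Queue

    def seeded_queue(items):
        q = Queue()
        q.queue = deque(items)             # bypass put(): no per-item lock traffic
        q.unfinished_tasks = len(q.queue)  # keep join()/task_done() bookkeeping correct
        return q

    q = seeded_queue(["a", "b", "c"])
    while q.unfinished_tasks:
        q.get()
        q.task_done()  # after the last task_done(), q.join() would return immediately
    print("drained")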
3 changes: 1 addition & 2 deletions src/uberjob/_registry.py
@@ -65,8 +65,7 @@ def source(self, plan: Plan, value_store: ValueStore) -> Node:
         validation.assert_is_instance(plan, "plan", Plan)
         validation.assert_is_instance(value_store, "value_store", ValueStore)
         stack_frame = get_stack_frame()
-        node = plan.call(source)
-        node.stack_frame = stack_frame
+        node = plan._call(stack_frame, source)
         self.mapping[node] = RegistryValue(
             value_store, is_source=True, stack_frame=stack_frame
         )
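This registry change removes a duplicate stack capture: plan.call(source) would walk the stack to record its own frame, which was then immediately overwritten with the frame the registry had already captured. Passing the existing frame into an internal _call does the walk once. A hedged sketch of the general pattern (names and structure hypothetical, not uberjob's actual internals):

    import inspect

    def get_stack_frame():
        # Walking the interpreter stack is comparatively expensive.
        return inspect.stack()[2]

    class Plan:
        def call(self, fn, *args):
            # Public entry point: captures the caller's frame itself.
            return self._call(get_stack_frame(), fn, *args)

        def _call(self, stack_frame, fn, *args):
            # Internal variant: a caller that already holds a frame passes it
            # in, avoiding a second capture.
            return {"fn": fn, "args": args, "stack_frame": stack_frame}

    plan = Plan()
    node = plan.call(len, [1, 2, 3])
    print(node["stack_frame"].function)  # '<module>' at module scope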
12 changes: 11 additions & 1 deletion src/uberjob/_run.py
@@ -182,6 +182,7 @@ def run(
     progress: Union[None, bool, Progress, Iterable[Progress]] = True,
     scheduler: Optional[str] = None,
     transform_physical: Optional[Callable[[Plan, Node], Tuple[Plan, Node]]] = None,
+    stale_check_max_workers: Optional[int] = None,
 ):
     """
     Run a :class:`~uberjob.Plan`.
@@ -211,6 +212,8 @@
     :param transform_physical: Optional transformation to be applied to the physical plan.
         It takes ``(plan, output_node)`` as positional arguments and
         returns ``(transformed_plan, redirected_output_node)``.
+    :param stale_check_max_workers: Optionally specify the maximum number of threads used for the stale check.
+        The default behavior is to use ``max_workers``.
     :return: The non-symbolic output corresponding to the symbolic output argument.
     """
     assert_is_instance(plan, "plan", Plan)
@@ -224,9 +227,16 @@
     assert_is_instance(max_workers, "max_workers", int, optional=True)
     if max_workers is not None and max_workers < 1:
         raise ValueError("max_workers must be at least 1.")
+    assert_is_instance(
+        stale_check_max_workers, "stale_check_max_workers", int, optional=True
+    )
+    if stale_check_max_workers is not None and stale_check_max_workers < 1:
+        raise ValueError("stale_check_max_workers must be at least 1.")
     assert_is_instance(max_errors, "max_errors", int, optional=True)
     if max_errors is not None and max_errors < 0:
         raise ValueError("max_errors must be nonnegative.")
+    if stale_check_max_workers is None:
+        stale_check_max_workers = max_workers

     plan = get_mutable_plan(plan, inplace=False)
@@ -246,7 +256,7 @@
         registry,
         output_node=output_node,
         progress_observer=progress_observer,
-        max_workers=max_workers,
+        max_workers=stale_check_max_workers,
         retry=retry,
         fresh_time=fresh_time,
         inplace=True,
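The new stale_check_max_workers parameter decouples the thread count of the stale check, which is typically dominated by value-store I/O, from the thread count used to run the graph. A hypothetical invocation (the plan here is invented for illustration):

    import uberjob

    plan = uberjob.Plan()
    x = plan.call(lambda: 21)
    y = plan.call(lambda a: a * 2, x)

    # Run the graph with 4 workers, but let the (I/O-bound) stale check
    # fan out to 32 threads; it defaults to max_workers when omitted.
    result = uberjob.run(plan, output=y, max_workers=4, stale_check_max_workers=32)
    print(result)  # 42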
6 changes: 0 additions & 6 deletions src/uberjob/_util/networkx_util.py
@@ -38,12 +38,6 @@ def predecessor_count(graph, node) -> int:
     return len(graph.pred[node])


-def source_nodes(graph):
-    for node, predecessors in graph.pred.items():
-        if not predecessors:
-            yield node
-
-
 def is_source_node(graph, node) -> bool:
     return not graph.pred[node]
2 changes: 1 addition & 1 deletion src/uberjob/progress/_console_progress_observer.py
@@ -81,7 +81,7 @@ def _render(self, state, new_exception_index, exception_tuples, elapsed):
         print_ = partial(print, file=buffer)
         _print_header(print_, elapsed)
         for section in ("stale", "run"):
-            scope_mapping = state.section_scope_mapping.get(section)
+            scope_mapping = state.get(section)
             if scope_mapping:
                 is_done = all(
                     s.completed + s.failed == s.total
2 changes: 1 addition & 1 deletion src/uberjob/progress/_html_progress_observer.py
@@ -130,7 +130,7 @@ def _render(self, state, new_exception_index, exception_tuples, elapsed):
             ("stale", "Determining stale value stores"),
             ("run", "Running graph"),
         ):
-            scope_mapping = state.section_scope_mapping.get(section)
+            scope_mapping = state.get(section)
             if scope_mapping:
                 lines.append('<h3 class="mt-4">{}</h3>'.format(html.escape(title)))
                 for scope, scope_state in sorted_scope_items(scope_mapping):
2 changes: 1 addition & 1 deletion src/uberjob/progress/_ipython_progress_observer.py
@@ -68,7 +68,7 @@ def _render(self, state, new_exception_index, exception_tuples, elapsed):
             ("stale", "Determining stale value stores"),
             ("run", "Running graph"),
         ):
-            scope_mapping = state.section_scope_mapping.get(section)
+            scope_mapping = state.get(section)
             if scope_mapping:
                 title_widget = self._get(
                     "section", section, "title", default=widgets.HTML
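All three observers change for the same reason: _render's first argument is now the plain section-to-scope mapping rather than the State object (see the _simple_progress_observer.py diff below), so renderers stop reaching through State internals. One side benefit is that a renderer loop can be exercised with a bare dict; a hypothetical stand-in:

    def render_sections(section_scope_mapping):
        # Same probing pattern as the observers' _render loops.
        for section in ("stale", "run"):
            scope_mapping = section_scope_mapping.get(section)
            if scope_mapping:
                yield section, len(scope_mapping)

    print(list(render_sections({"run": {("my_scope",): "scope-state"}})))  # [('run', 1)]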
48 changes: 21 additions & 27 deletions src/uberjob/progress/_simple_progress_observer.py
@@ -17,8 +17,6 @@
 import threading
 import time
 from abc import ABC, abstractmethod
-from collections import defaultdict
-from contextlib import contextmanager
 from typing import Tuple

 from uberjob.progress._progress_observer import ProgressObserver
@@ -73,14 +71,16 @@ def to_progress_string(self):


 class State:
-    def __init__(self):
-        self.section_scope_mapping = defaultdict(lambda: defaultdict(ScopeState))
+    def __init__(self, start_time):
+        self.section_scope_mapping = {}
         self.running_count = 0
         self._running_section_scopes = set()
-        self._prev_time = None
+        self._prev_time = start_time

     def increment_total(self, section, scope, amount: int):
-        self.section_scope_mapping[section][scope].total += amount
+        self.section_scope_mapping.setdefault(section, {}).setdefault(
+            scope, ScopeState()
+        ).total += amount

     def increment_running(self, section, scope):
         self.update_weighted_elapsed()
@@ -110,14 +110,12 @@ def increment_failed(self, section, scope):

     def update_weighted_elapsed(self):
         t = time.time()
-        if self._prev_time:
+        if self.running_count:
             elapsed = t - self._prev_time
-            if self.running_count:
-                for section, scope in self._running_section_scopes:
-                    scope_state = self.section_scope_mapping[section][scope]
-                    scope_state.weighted_elapsed += (
-                        elapsed * scope_state.running / self.running_count
-                    )
+            multiplier = elapsed / self.running_count
+            for section, scope in self._running_section_scopes:
+                scope_state = self.section_scope_mapping[section][scope]
+                scope_state.weighted_elapsed += scope_state.running * multiplier
         self._prev_time = t

@@ -134,15 +132,14 @@ def __init__(
         self._min_update_interval = min_update_interval
         self._max_update_interval = max_update_interval
         self._max_exception_count = max_exception_count
-        self._state = State()
-        self._running_scope_lookup = defaultdict(set)
         self._exception_tuples = []
         self._new_exception_index = 0
         self._lock = threading.Lock()
         self._stale = True
         self._done_event = threading.Event()
         self._thread = None
-        self._start_time = None
+        self._start_time = time.time()
+        self._state = State(self._start_time)
         self._last_render_time = None

     @abstractmethod
@@ -154,7 +151,6 @@ def _output(self, value):
         pass

     def __enter__(self):
-        self._start_time = time.time()
         self._thread = threading.Thread(target=self._run_update_thread)
         self._thread.start()

@@ -174,7 +170,7 @@ def _do_render(self):
         self._last_render_time = t
         self._state.update_weighted_elapsed()
         output_value = self._render(
-            self._state,
+            self._state.section_scope_mapping,
             self._new_exception_index,
             self._exception_tuples,
             t - self._start_time,
@@ -196,26 +192,24 @@ def _run_update_thread(self):
         if output_value is not None:
             self._output(output_value)

-    @contextmanager
-    def _lock_and_make_stale(self):
-        with self._lock:
-            self._stale = True
-            yield
-
     def increment_total(self, *, section: str, scope: Tuple, amount: int):
-        with self._lock_and_make_stale():
+        with self._lock:
+            self._stale = True
             self._state.increment_total(section, scope, amount)

     def increment_running(self, *, section: str, scope: Tuple):
-        with self._lock_and_make_stale():
+        with self._lock:
+            self._stale = True
             self._state.increment_running(section, scope)

     def increment_completed(self, *, section: str, scope: Tuple):
-        with self._lock_and_make_stale():
+        with self._lock:
+            self._stale = True
             self._state.increment_completed(section, scope)

     def increment_failed(self, *, section: str, scope: Tuple, exception: Exception):
-        with self._lock_and_make_stale():
+        with self._lock:
+            self._stale = True
             self._state.increment_failed(section, scope)
             if len(self._exception_tuples) < self._max_exception_count:
                 self._exception_tuples.append(
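Two details worth noting in this file. First, State drops the nested defaultdict in favor of a plain dict plus setdefault. One hazard this removes, now that the raw mapping is passed straight to renderers, is defaultdict's read-side mutation: a bracketed lookup inserts the missing key as a side effect, while a plain dict never mutates on lookup:

    from collections import defaultdict

    auto = defaultdict(dict)
    _ = auto["stale"]        # a mere read inserts the key
    print(list(auto))        # ['stale']

    plain = {}
    _ = plain.get("stale")   # .get never inserts
    plain.setdefault("run", {})["scope"] = 1  # writers create entries explicitly
    print(list(plain))       # ['run']

Second, update_weighted_elapsed hoists the division out of the per-scope loop: elapsed / running_count is computed once as multiplier, replacing one division per running scope with a single division plus one multiplication per scope.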