From cc553548b2ff0c1c7496e022e93bc7dd23c834cb Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Fri, 29 Sep 2023 17:03:36 -0500 Subject: [PATCH 1/6] bump version to 2023.9.3 --- docs/source/changelog.rst | 34 ++++++++++++++++++++++++++++++++++ pyproject.toml | 2 +- 2 files changed, 35 insertions(+), 1 deletion(-) diff --git a/docs/source/changelog.rst b/docs/source/changelog.rst index c15083aaa22..23c057ac266 100644 --- a/docs/source/changelog.rst +++ b/docs/source/changelog.rst @@ -1,6 +1,39 @@ Changelog ========= +.. _v2023.9.3: + +2023.9.3 +-------- + +Released on September 29, 2023 + +Highlights +^^^^^^^^^^ + +Restore previous configuration override behavior +"""""""""""""""""""""""""""""""""""""""""""""""" +The 2023.9.2 release introduced an unintentional breaking change in +how configuration options are overridden in ``dask.config.get`` with +the ``override_with=`` keyword (see :issue:`10519`). +This release restores the previous behavior. + +See :pr:`10521` from `crusaderky`_ for details. + +Complex dtypes in Dask Array reductions +""""""""""""""""""""""""""""""""""""""" +This release includes improved support for using common reductions +in Dask Array (e.g. ``var``, ``std``, ``moment``) with complex dtypes. + +See :pr:`10009` from `wkrasnicki`_ for details. + +.. dropdown:: Additional changes + + - Bump ``actions/checkout`` from 4.0.0 to 4.1.0 (:pr:`10532`) + - Match ``pandas`` reverting ``apply`` deprecation (:pr:`10531`) `James Bourbeau`_ + - Update gpuCI ``RAPIDS_VER`` to ``23.12`` (:pr:`10526`) + - Temporarily skip failing tests with ``fsspec==2023.9.1`` (:pr:`10520`) `James Bourbeau`_ + .. _v2023.9.2: 2023.9.2 -------- @@ -7024,3 +7057,4 @@ Other .. _`Alexander Clausen`: https://github.com/sk1p .. _`Swayam Patil`: https://github.com/Swish78 .. _`Johan Olsson`: https://github.com/johanols +..
_`wkrasnicki`: https://github.com/wkrasnicki diff --git a/pyproject.toml b/pyproject.toml index 6b861e5e171..c2903525307 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,7 +55,7 @@ dataframe = [ "dask[array]", "pandas >= 1.3", ] -distributed = ["distributed == 2023.9.2"] +distributed = ["distributed == 2023.9.3"] diagnostics = [ "bokeh >= 2.4.2", "jinja2 >= 2.10.3", From 8ba447549ea1ffac3832bdf3f9e11c1b4ed118d8 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Wed, 4 Oct 2023 13:02:43 +0100 Subject: [PATCH 2/6] Tighten HighLevelGraph annotations (#10524) --- dask/array/core.py | 6 ++-- dask/bag/core.py | 4 +-- dask/bag/tests/test_bag.py | 3 +- dask/graph_manipulation.py | 4 +-- dask/highlevelgraph.py | 73 +++++++++++++++++++++++--------------- 5 files changed, 54 insertions(+), 36 deletions(-) diff --git a/dask/array/core.py b/dask/array/core.py index 1a1122d96df..283faefecaa 100644 --- a/dask/array/core.py +++ b/dask/array/core.py @@ -5803,10 +5803,10 @@ def __getitem__(self, index: Any) -> Array: keys = product(*(range(len(c)) for c in chunks)) - layer = {(name,) + key: tuple(new_keys[key].tolist()) for key in keys} + graph: Graph = {(name,) + key: tuple(new_keys[key].tolist()) for key in keys} - graph = HighLevelGraph.from_collections(name, layer, dependencies=[self._array]) - return Array(graph, name, chunks, meta=self._array) + hlg = HighLevelGraph.from_collections(name, graph, dependencies=[self._array]) + return Array(hlg, name, chunks, meta=self._array) def __eq__(self, other: Any) -> bool: if isinstance(other, BlockView): diff --git a/dask/bag/core.py b/dask/bag/core.py index 4603ae57cdc..e3a6f675def 100644 --- a/dask/bag/core.py +++ b/dask/bag/core.py @@ -7,7 +7,7 @@ import uuid import warnings from collections import defaultdict -from collections.abc import Iterable, Iterator, Mapping, Sequence +from collections.abc import Iterable, Iterator, Sequence from functools import partial, reduce, wraps from random import Random from urllib.request import urlopen @@ -469,7 +469,7 @@ class Bag(DaskMethodsMixin): 30 """ - def __init__(self, dsk: Mapping, name: str, npartitions: int): + def __init__(self, dsk: Graph, name: str, npartitions: int): if not isinstance(dsk, HighLevelGraph): dsk = HighLevelGraph.from_collections(name, dsk, dependencies=[]) self.dask = dsk diff --git a/dask/bag/tests/test_bag.py b/dask/bag/tests/test_bag.py index ab46224b56b..2926122b412 100644 --- a/dask/bag/tests/test_bag.py +++ b/dask/bag/tests/test_bag.py @@ -34,10 +34,11 @@ from dask.bag.utils import assert_eq from dask.blockwise import Blockwise from dask.delayed import Delayed +from dask.typing import Graph from dask.utils import filetexts, tmpdir, tmpfile from dask.utils_test import add, hlg_layer, hlg_layer_topological, inc -dsk = {("x", 0): (range, 5), ("x", 1): (range, 5), ("x", 2): (range, 5)} +dsk: Graph = {("x", 0): (range, 5), ("x", 1): (range, 5), ("x", 2): (range, 5)} L = list(range(5)) * 3 diff --git a/dask/graph_manipulation.py b/dask/graph_manipulation.py index a80fd8112a8..3c4cbec8d2b 100644 --- a/dask/graph_manipulation.py +++ b/dask/graph_manipulation.py @@ -20,7 +20,7 @@ from dask.core import flatten from dask.delayed import Delayed, delayed from dask.highlevelgraph import HighLevelGraph, Layer, MaterializedLayer -from dask.typing import Key +from dask.typing import Graph, Key __all__ = ("bind", "checkpoint", "clone", "wait_on") @@ -78,7 +78,7 @@ def _checkpoint_one(collection, split_every) -> Delayed: next(keys_iter) except StopIteration: # Collection has 0 or 1 keys; no need for a 
map step - layer = {name: (chunks.checkpoint, collection.__dask_keys__())} + layer: Graph = {name: (chunks.checkpoint, collection.__dask_keys__())} dsk = HighLevelGraph.from_collections(name, layer, dependencies=(collection,)) return Delayed(name, dsk) diff --git a/dask/highlevelgraph.py b/dask/highlevelgraph.py index 19aa3befe5a..eb13ff013f7 100644 --- a/dask/highlevelgraph.py +++ b/dask/highlevelgraph.py @@ -3,7 +3,18 @@ import abc import copy import html -from collections.abc import Collection, Hashable, Iterable, KeysView, Mapping, Set +from collections.abc import ( + Collection, + Hashable, + ItemsView, + Iterable, + Iterator, + KeysView, + Mapping, + Sequence, + Set, + ValuesView, +) from typing import Any import tlz as toolz @@ -12,7 +23,7 @@ from dask import config from dask.base import clone_key, flatten, is_dask_collection from dask.core import keys_in_tasks, reverse_dict -from dask.typing import Graph, Key +from dask.typing import DaskCollection, Graph, Key from dask.utils import ensure_dict, import_required, key_split from dask.widgets import get_template @@ -34,7 +45,7 @@ def _find_layer_containing_key(key): return ret -class Layer(Mapping): +class Layer(Graph): """High level graph layer This abstract class establish a protocol for high level graph layers. @@ -327,7 +338,7 @@ def get_output_keys(self): return self.keys() -class HighLevelGraph(Mapping): +class HighLevelGraph(Graph): """Task graph composed of layers of dependent subgraphs This object encodes a Dask task graph that is composed of layers of @@ -433,7 +444,12 @@ def _from_collection(cls, name, layer, collection): return cls(layers, deps) @classmethod - def from_collections(cls, name, layer, dependencies=()): + def from_collections( + cls, + name: str, + layer: Graph, + dependencies: Sequence[DaskCollection] = (), + ) -> HighLevelGraph: """Construct a HighLevelGraph from a new layer and a set of collections This constructs a HighLevelGraph in the common case where we have a single @@ -470,34 +486,35 @@ def from_collections(cls, name, layer, dependencies=()): if len(dependencies) == 1: return cls._from_collection(name, layer, dependencies[0]) layers = {name: layer} - deps = {name: set()} + name_dep: set[str] = set() + deps: dict[str, Set[str]] = {name: name_dep} for collection in toolz.unique(dependencies, key=id): if is_dask_collection(collection): graph = collection.__dask_graph__() if isinstance(graph, HighLevelGraph): layers.update(graph.layers) deps.update(graph.dependencies) - deps[name] |= set(collection.__dask_layers__()) + name_dep |= set(collection.__dask_layers__()) else: key = _get_some_layer_name(collection) layers[key] = graph - deps[name].add(key) + name_dep.add(key) deps[key] = set() else: raise TypeError(type(collection)) return cls(layers, deps) - def __getitem__(self, key): + def __getitem__(self, key: Key) -> Any: # Attempt O(1) direct access first, under the assumption that layer names match # either the keys (Scalar, Item, Delayed) or the first element of the key tuples # (Array, Bag, DataFrame, Series). This assumption is not always true. 
try: - return self.layers[key][key] + return self.layers[key][key] # type: ignore except KeyError: pass try: - return self.layers[key[0]][key] + return self.layers[key[0]][key] # type: ignore except (KeyError, IndexError, TypeError): pass @@ -518,10 +535,10 @@ def __len__(self) -> int: # https://github.com/dask/dask/issues/7271 return sum(len(layer) for layer in self.layers.values()) - def __iter__(self): + def __iter__(self) -> Iterator[Key]: return iter(self.to_dict()) - def to_dict(self) -> dict: + def to_dict(self) -> dict[Key, Any]: """Efficiently convert to plain dict. This method is faster than dict(self).""" try: return self._to_dict @@ -537,7 +554,7 @@ def keys(self) -> KeysView: """ return self.to_dict().keys() - def get_all_external_keys(self) -> set: + def get_all_external_keys(self) -> set[Key]: """Get all output keys of all layers This will in most cases _not_ materialize any layers, which makes @@ -560,10 +577,10 @@ def get_all_external_keys(self) -> set: self._all_external_keys = keys return keys - def items(self): + def items(self) -> ItemsView[Key, Any]: return self.to_dict().items() - def values(self): + def values(self) -> ValuesView[Any]: return self.to_dict().values() def get_all_dependencies(self) -> dict[Key, Set[Key]]: @@ -586,10 +603,10 @@ def get_all_dependencies(self) -> dict[Key, Set[Key]]: return self.key_dependencies @property - def dependents(self): + def dependents(self) -> dict[str, set[str]]: return reverse_dict(self.dependencies) - def copy(self): + def copy(self) -> HighLevelGraph: return HighLevelGraph( ensure_dict(self.layers, copy=True), ensure_dict(self.dependencies, copy=True), @@ -597,16 +614,16 @@ def copy(self): ) @classmethod - def merge(cls, *graphs): - layers = {} - dependencies = {} + def merge(cls, *graphs: Graph) -> HighLevelGraph: + layers: dict[str, Graph] = {} + dependencies: dict[str, Set[str]] = {} for g in graphs: if isinstance(g, HighLevelGraph): layers.update(g.layers) dependencies.update(g.dependencies) elif isinstance(g, Mapping): - layers[id(g)] = g - dependencies[id(g)] = set() + layers[str(id(g))] = g + dependencies[str(id(g))] = set() else: raise TypeError(g) return cls(layers, dependencies) @@ -655,7 +672,7 @@ def visualize(self, filename="dask-hlg.svg", format=None, **kwargs): graphviz_to_file(g, filename, format) return g - def _toposort_layers(self): + def _toposort_layers(self) -> list[str]: """Sort the layers in a high level graph topologically Parameters @@ -669,7 +686,7 @@ def _toposort_layers(self): List of layer names sorted topologically """ degree = {k: len(v) for k, v in self.dependencies.items()} - reverse_deps = {k: [] for k in self.dependencies} + reverse_deps: dict[str, list[str]] = {k: [] for k in self.dependencies} ready = [] for k, v in self.dependencies.items(): for dep in v: @@ -686,7 +703,7 @@ def _toposort_layers(self): ready.append(rdep) return ret - def cull(self, keys: Iterable) -> HighLevelGraph: + def cull(self, keys: Iterable[Key]) -> HighLevelGraph: """Return new HighLevelGraph with only the tasks required to calculate keys. In other words, remove unnecessary tasks from dask. @@ -779,7 +796,7 @@ def cull_layers(self, layers: Iterable[str]) -> HighLevelGraph: return HighLevelGraph(ret_layers, ret_dependencies) - def validate(self): + def validate(self) -> None: # Check dependencies for layer_name, deps in self.dependencies.items(): if layer_name not in self.layers: @@ -820,7 +837,7 @@ def __repr__(self) -> str: representation += f" {i}. 
{layerkey}\n" return representation - def _repr_html_(self): + def _repr_html_(self) -> str: return get_template("highlevelgraph.html.j2").render( type=type(self).__name__, layers=self.layers, From 62dbbffa398322052aeabb189a3be770af1d51bf Mon Sep 17 00:00:00 2001 From: Michael Leslie Date: Wed, 4 Oct 2023 12:04:37 -0700 Subject: [PATCH 3/6] Allow passing index_col=False in dd.read_csv (#9961) --- dask/dataframe/io/csv.py | 8 +++++--- dask/dataframe/io/tests/test_csv.py | 4 ++++ 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/dask/dataframe/io/csv.py b/dask/dataframe/io/csv.py index 1d5189e8bf6..5a847789c2c 100644 --- a/dask/dataframe/io/csv.py +++ b/dask/dataframe/io/csv.py @@ -486,10 +486,12 @@ def read_pandas( lineterminator = "\n" if include_path_column and isinstance(include_path_column, bool): include_path_column = "path" - if "index" in kwargs or "index_col" in kwargs: + if "index" in kwargs or ( + "index_col" in kwargs and kwargs.get("index_col") is not False + ): raise ValueError( - "Keywords 'index' and 'index_col' not supported. " - f"Use dd.{reader_name}(...).set_index('my-index') instead" + "Keywords 'index' and 'index_col' not supported, except for " + "'index_col=False'. Use dd.{reader_name}(...).set_index('my-index') instead" ) for kw in ["iterator", "chunksize"]: if kw in kwargs: diff --git a/dask/dataframe/io/tests/test_csv.py b/dask/dataframe/io/tests/test_csv.py index 67df4c9df36..1df7202f2d3 100644 --- a/dask/dataframe/io/tests/test_csv.py +++ b/dask/dataframe/io/tests/test_csv.py @@ -1117,6 +1117,10 @@ def test_index_col(): except ValueError as e: assert "set_index" in str(e) + df = pd.read_csv(fn, index_col=False) + ddf = dd.read_csv(fn, blocksize=30, index_col=False) + assert_eq(df, ddf, check_index=False) + def test_read_csv_with_datetime_index_partitions_one(): with filetext(timeseries) as fn: From 928a95aa56f60da33a4e724ea2ca97797c612968 Mon Sep 17 00:00:00 2001 From: Florian Jetter Date: Fri, 6 Oct 2023 12:48:10 +0200 Subject: [PATCH 4/6] Improve cache hits for tuple keys in `key_split` and intern results (#10547) --- dask/utils.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/dask/utils.py b/dask/utils.py index 6713fb3b86f..e2ebb6a588e 100644 --- a/dask/utils.py +++ b/dask/utils.py @@ -1851,10 +1851,11 @@ def key_split(s): >>> key_split('_(x)') # strips unpleasant characters 'x' """ + # If we convert the key, recurse to utilize LRU cache better if type(s) is bytes: - s = s.decode() + return key_split(s.decode()) if type(s) is tuple: - s = s[0] + return key_split(s[0]) try: words = s.split("-") if not words[0][0].isalpha(): @@ -1873,7 +1874,7 @@ def key_split(s): else: if result[0] == "<": result = result.strip("<>").split()[0].split(".")[-1] - return result + return sys.intern(result) except Exception: return "Other" From 1a9ba0107995eb0c9c7d91c71adf31877b7de440 Mon Sep 17 00:00:00 2001 From: Florian Jetter Date: Tue, 10 Oct 2023 13:01:41 +0200 Subject: [PATCH 5/6] [dask.order] Reduce memory pressure for multi array reductions by releasing splitter tasks more eagerly (#10535) --- dask/order.py | 330 ++++++++---------- dask/tests/test_order.py | 723 +++++++++++++++++++++++++++++++++++---- 2 files changed, 812 insertions(+), 241 deletions(-) diff --git a/dask/order.py b/dask/order.py index 46e8c314508..457841e49ec 100644 --- a/dask/order.py +++ b/dask/order.py @@ -80,7 +80,7 @@ from collections import defaultdict, namedtuple from math import log -from dask.core import get_dependencies, get_deps, getcycle, reverse_dict +from 
dask.core import get_dependencies, get_deps, getcycle, istask, reverse_dict def order(dsk, dependencies=None): @@ -110,6 +110,7 @@ def order(dsk, dependencies=None): """ if not dsk: return {} + dsk = dict(dsk) if dependencies is None: dependencies = {k: get_dependencies(dsk, k) for k in dsk} @@ -133,7 +134,19 @@ def order(dsk, dependencies=None): # tree, we skip processing it normally. # See https://github.com/dask/dask/issues/6745 root_nodes = {k for k, v in dependents.items() if not v} - skip_root_node = len(root_nodes) == 1 and len(dsk) > 1 + if len(root_nodes) > 1: + # This is also nice because it makes us robust to difference when + # computing vs persisting collections + root = object() + + def _f(*args, **kwargs): + pass + + dsk[root] = (_f, *root_nodes) + dependencies[root] = root_nodes + o = order(dsk, dependencies) + del o[root] + return o # Leaf nodes. We choose one--the initial node--for each weakly connected subgraph. # Let's calculate the `initial_stack_key` as we determine `init_stack` set. @@ -141,10 +154,8 @@ def order(dsk, dependencies=None): # First prioritize large, tall groups, then prioritize the same as ``dependents_key``. key: ( # at a high-level, work towards a large goal (and prefer tall and narrow) - -max_dependencies, num_dependents - max_heights, # tactically, finish small connected jobs first - min_dependencies, num_dependents - min_heights, # prefer tall and narrow -total_dependents, # take a big step # try to be memory efficient @@ -154,8 +165,8 @@ def order(dsk, dependencies=None): ) for key, num_dependents, ( total_dependents, - min_dependencies, - max_dependencies, + _, + _, min_heights, max_heights, ) in ( @@ -177,6 +188,9 @@ def dependents_key(x): return ( # Focus on being memory-efficient len(dependents[x]) - len(dependencies[x]) + num_needed[x], + # Do we favor deep or shallow branches? + # -1: deep + # +1: shallow -metrics[x][3], # min_heights # tie-breaker StrComparable(x), @@ -190,19 +204,15 @@ def dependencies_key(x): num_dependents = len(dependents[x]) ( total_dependents, - min_dependencies, - max_dependencies, + _, + _, min_heights, max_heights, ) = metrics[x] # Prefer short and narrow instead of tall in narrow, because we're going in # reverse along dependencies. return ( - # at a high-level, work towards a large goal (and prefer short and narrow) - -max_dependencies, num_dependents + max_heights, - # tactically, finish small connected jobs first - min_dependencies, num_dependents + min_heights, # prefer short and narrow -total_dependencies[x], # go where the work is # try to be memory efficient @@ -213,14 +223,11 @@ def dependencies_key(x): StrComparable(x), ) - def finish_now_key(x): - """Determine the order of dependents that are ready to run and be released""" - return (-len(dependencies[x]), StrComparable(x)) - + root_total_dependencies = total_dependencies[list(root_nodes)[0]] # Computing this for all keys can sometimes be relatively expensive :( partition_keys = { key: ( - (min_dependencies - total_dependencies[key] + 1) + (root_total_dependencies - total_dependencies[key] + 1) * (total_dependents - min_heights) ) for key, ( @@ -258,13 +265,10 @@ def finish_now_key(x): # via `partition_key`. A dependent goes to: # 1) `inner_stack` if it's better than our current target, # 2) `next_nodes` if the partition key is lower than it's parent, - # 3) `later_nodes` otherwise. - # When the inner stacks are depleted, we process `next_nodes`. If `next_nodes` is - # empty (and `outer_stacks` is empty`), then we process `later_nodes` the same way. 
+ # When the inner stacks are depleted, we process `next_nodes`. # These dicts use `partition_keys` as keys. We process them by placing the values # in `outer_stack` so that the smallest keys will be processed first. next_nodes = defaultdict(list) - later_nodes = defaultdict(list) # `outer_stack` is used to populate `inner_stacks`. From the time we partition the # dependents of a node, we group them: one list per partition key per parent node. @@ -279,10 +283,7 @@ def finish_now_key(x): # Keep track of nodes that are in `inner_stack` or `inner_stacks` so we don't # process them again. - if skip_root_node: - seen = set(root_nodes) - else: - seen = set() # seen in an inner_stack (and has dependencies) + seen = set(root_nodes) seen_update = seen.update seen_add = seen.add @@ -325,7 +326,6 @@ def finish_now_key(x): # if they so choose? Maybe. However, I'm sensitive to the multithreaded scheduler, # which is heavily dependent on the ordering obtained here. singles = {} - singles_items = singles.items() singles_clear = singles.clear later_singles = [] later_singles_append = later_singles.append @@ -338,13 +338,13 @@ def finish_now_key(x): # 4. later_singles # 5. next_nodes # 6. outer_stack - # 7. later_nodes - # 8. init_stack + # 7. init_stack # alias for speed set_difference = set.difference is_init_sorted = False + while True: while True: # Perform a DFS along dependencies until we complete our tactical goal @@ -353,7 +353,7 @@ def finish_now_key(x): if item in result: continue if num_needed[item]: - if not skip_root_node or item not in root_nodes: + if item not in root_nodes: inner_stack.append(item) deps = set_difference(dependencies[item], result) if 1 < len(deps) < 1000: @@ -365,6 +365,13 @@ def finish_now_key(x): seen_update(deps) if not singles: continue + # Only process singles once the inner_stack is fully + # resolved. This is important because the singles path later + # on verifies that running the single indeed opens an + # opportunity to release soon by comparing the singles + # parent's dependents with the inner_stack(s) + if inner_stack and num_needed[inner_stack[-1]]: + continue process_singles = True else: result[item] = i @@ -372,28 +379,6 @@ def finish_now_key(x): deps = dependents[item] add_to_inner_stack = True - if metrics[item][3] == 1: # min_height - # Don't leave any dangling single nodes! Finish all dependents that are - # ready and are also root nodes. - finish_now = { - dep - for dep in deps - if not dependents[dep] and num_needed[dep] == 1 - } - if finish_now: - deps -= finish_now # Safe to mutate - if len(finish_now) > 1: - finish_now = sorted(finish_now, key=finish_now_key) - for dep in finish_now: - result[dep] = i - i += 1 - add_to_inner_stack = False - elif skip_root_node: - for dep in root_nodes: - num_needed[dep] -= 1 - # Use remove here to complain loudly if our assumptions change - deps.remove(dep) # Safe to mutate - if deps: for dep in deps: num_needed[dep] -= 1 @@ -417,26 +402,6 @@ def finish_now_key(x): dep2 = dependents[single] result[single] = i i += 1 - if metrics[single][3] == 1: # min_height - # Don't leave any dangling single nodes! Finish all dependents that are - # ready and are also root nodes. 
- finish_now = { - dep - for dep in dep2 - if not dependents[dep] and num_needed[dep] == 1 - } - if finish_now: - dep2 -= finish_now # Safe to mutate - if len(finish_now) > 1: - finish_now = sorted(finish_now, key=finish_now_key) - for dep in finish_now: - result[dep] = i - i += 1 - elif skip_root_node: - for dep in root_nodes: - num_needed[dep] -= 1 - # Use remove here to complain loudly if our assumptions change - dep2.remove(dep) # Safe to mutate if dep2: for dep in dep2: num_needed[dep] -= 1 @@ -460,51 +425,48 @@ def finish_now_key(x): if process_singles and singles: # We gather all dependents of all singles into `deps`, which we then process below. - # A lingering question is: what should we use for `item`? `item_key` is used to - # determine whether each dependent goes to `next_nodes` or `later_nodes`. Currently, - # we use the last value of `item` (i.e., we don't do anything). + deps = set() add_to_inner_stack = True if inner_stack or inner_stacks else False - for single, parent in singles_items: - if single in result: - continue + singles_keys = set_difference(set(singles), result) + + # NOTE: If this was too slow, LIFO would be a decent + # approximation + for single in sorted(singles_keys, key=lambda x: partition_keys[x]): + # We want to run the singles if they are either releasing a + # dependency directly or that they may be releasing a + # dependency once the current critical path / inner_stack is + # walked. + # By using `seen` here this is more permissive since it also + # includes tasks in a future critical path / inner_stacks + # but it would require additional state to make this + # distinction and we don't have enough data to dermine if + # this is worth it. + parent = singles[single] if ( - add_to_inner_stack - and len(set_difference(dependents[parent], result)) > 1 + len( + set_difference( + set_difference(dependents[parent], result), + seen, + ) + ) + > 1 ): later_singles_append(single) continue - while True: dep2 = dependents[single] result[single] = i i += 1 - if metrics[single][3] == 1: # min_height - # Don't leave any dangling single nodes! Finish all dependents that are - # ready and are also root nodes. - finish_now = { - dep - for dep in dep2 - if not dependents[dep] and num_needed[dep] == 1 - } - if finish_now: - dep2 -= finish_now # Safe to mutate - if len(finish_now) > 1: - finish_now = sorted(finish_now, key=finish_now_key) - for dep in finish_now: - result[dep] = i - i += 1 - elif skip_root_node: - for dep in root_nodes: - num_needed[dep] -= 1 - # Use remove here to complain loudly if our assumptions change - dep2.remove(dep) # Safe to mutate if dep2: for dep in dep2: num_needed[dep] -= 1 if add_to_inner_stack: already_seen = dep2 & seen if already_seen: + # This means that the singles path also + # leads to the current or previous strategic + # path if len(dep2) == len(already_seen): if len(already_seen) == 1: (single,) = already_seen @@ -535,8 +497,7 @@ def finish_now_key(x): add_to_inner_stack = False # If inner_stack is empty, then we typically add the best dependent to it. - # However, we don't add to it if we complete a node early via "finish_now" above - # or if a dependent is already on an inner_stack. In this case, we add the + # However, we don't add to it if a dependent is already on an inner_stack. In this case, we add the # dependents (not in an inner_stack) to next_nodes or later_nodes to handle later. # This serves three purposes: # 1. 
shrink `deps` so that it can be processed faster, @@ -556,30 +517,18 @@ def finish_now_key(x): if len(deps) == 1: # Fast path! We trim down `deps` above hoping to reach here. (dep,) = deps - if not inner_stack: - if add_to_inner_stack: - inner_stack = [dep] - inner_stack_pop = inner_stack.pop - seen_add(dep) - continue - key = partition_keys[dep] - else: - key = partition_keys[dep] - if key < partition_keys[inner_stack[0]]: - # Run before `inner_stack` (change tactical goal!) - inner_stacks_append(inner_stack) - inner_stack = [dep] - inner_stack_pop = inner_stack.pop - seen_add(dep) - continue + if add_to_inner_stack and not inner_stack: + inner_stack = [dep] + inner_stack_pop = inner_stack.pop + seen_add(dep) + continue + key = partition_keys[dep] if not num_needed[dep]: # We didn't put the single dependency on the stack, but we should still # run it soon, because doing so may free its parent. singles[dep] = item - elif key < partition_keys[item]: - next_nodes[key].append(deps) else: - later_nodes[key].append(deps) + next_nodes[key].append(deps) elif len(deps) == 2: # We special-case when len(deps) == 2 so that we may place a dep on singles. # Otherwise, the logic here is the same as when `len(deps) > 2` below. @@ -620,60 +569,43 @@ def finish_now_key(x): later_singles_append(dep2) else: singles[dep2] = item - elif key2 < partition_keys[item]: - next_nodes[key2].append([dep2]) else: - later_nodes[key2].append([dep2]) + next_nodes[key2].append([dep2]) else: item_key = partition_keys[item] - if key2 < item_key: - next_nodes[key].append([dep]) - next_nodes[key2].append([dep2]) - elif key < item_key: - next_nodes[key].append([dep]) - later_nodes[key2].append([dep2]) - else: - later_nodes[key].append([dep]) - later_nodes[key2].append([dep2]) + for k, d in [(key, dep), (key2, dep2)]: + if not num_needed[d]: + if process_singles: + later_singles_append(d) + else: + singles[d] = item + else: + next_nodes[k].append([d]) else: + assert not inner_stack if add_to_inner_stack: + inner_stack = [dep] + inner_stack_pop = inner_stack.pop + seen_add(dep) if not num_needed[dep2]: - inner_stacks_append(inner_stack) - inner_stack = [dep] - inner_stack_pop = inner_stack.pop - seen_add(dep) singles[dep2] = item elif key == key2 and 5 * partition_keys[item] > 22 * key: - inner_stacks_append(inner_stack) inner_stacks_append([dep2]) - inner_stack = [dep] - inner_stack_pop = inner_stack.pop - seen_update(deps) + seen_add(dep2) else: - inner_stacks_append(inner_stack) - inner_stack = [dep] - inner_stack_pop = inner_stack.pop - seen_add(dep) - if key2 < partition_keys[item]: - next_nodes[key2].append([dep2]) - else: - later_nodes[key2].append([dep2]) - else: - item_key = partition_keys[item] - if key2 < item_key: - next_nodes[key].append([dep]) next_nodes[key2].append([dep2]) - elif key < item_key: - next_nodes[key].append([dep]) - later_nodes[key2].append([dep2]) - else: - later_nodes[key].append([dep]) - later_nodes[key2].append([dep2]) + else: + for k, d in [(key, dep), (key2, dep2)]: + next_nodes[k].append([d]) else: # Slow path :(. This requires grouping by partition_key. 
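+ # Bucket each dependent by its partition_key; dependents that are already ready to run (num_needed == 0) are additionally tracked as possible singles so they can be scheduled eagerly and release their parents sooner.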
- dep_pools = defaultdict(list) + dep_pools = defaultdict(set) + possible_singles = defaultdict(set) for dep in deps: - dep_pools[partition_keys[dep]].append(dep) + pkey = partition_keys[dep] + if not num_needed[dep] and not process_singles: + possible_singles[pkey].add(dep) + dep_pools[pkey].add(dep) item_key = partition_keys[item] if inner_stack: # If we have an inner_stack, we need to look for a "better" path @@ -682,10 +614,12 @@ def finish_now_key(x): for key, vals in dep_pools.items(): if key < prev_key: now_keys.append(key) - elif key < item_key: - next_nodes[key].append(vals) else: - later_nodes[key].append(vals) + psingles = possible_singles[key] + for s in psingles: + singles[s] = item + vals -= psingles + next_nodes[key].append(vals) if now_keys: # Run before `inner_stack` (change tactical goal!) inner_stacks_append(inner_stack) @@ -694,7 +628,7 @@ def finish_now_key(x): for key in now_keys: pool = dep_pools[key] if 1 < len(pool) < 100: - pool.sort(key=dependents_key, reverse=True) + pool = sorted(pool, key=dependents_key, reverse=True) inner_stacks_extend([dep] for dep in pool) seen_update(pool) inner_stack = inner_stacks_pop() @@ -706,7 +640,7 @@ def finish_now_key(x): min_key = min(dep_pools) min_pool = dep_pools.pop(min_key) if len(min_pool) == 1: - inner_stack = min_pool + inner_stack = list(min_pool) seen_update(inner_stack) elif ( 10 * item_key > 11 * len(min_pool) * len(min_pool) * min_key @@ -721,7 +655,9 @@ def finish_now_key(x): # what we have easily available. It is obviously very specific to our # choice of partition_key. Dask tests take this route about 40%. if len(min_pool) < 100: - min_pool.sort(key=dependents_key, reverse=True) + min_pool = sorted( + min_pool, key=dependents_key, reverse=True + ) inner_stacks_extend([dep] for dep in min_pool) inner_stack = inner_stacks_pop() seen_update(min_pool) @@ -736,10 +672,11 @@ def finish_now_key(x): inner_stack_pop = inner_stack.pop for key, vals in dep_pools.items(): - if key < item_key: - next_nodes[key].append(vals) - else: - later_nodes[key].append(vals) + psingles = possible_singles[key] + for s in psingles: + singles[s] = item + vals -= psingles + next_nodes[key].append(vals) if len(dependencies) == len(result): break # all done! @@ -765,12 +702,6 @@ def finish_now_key(x): if inner_stacks: continue - if later_nodes: - # You know all those dependents with large keys we've been hanging onto to run "later"? - # Well, "later" has finally come. - next_nodes, later_nodes = later_nodes, next_nodes - continue - # We just finished computing a connected group. # Let's choose the first `item` in the next group to compute. # If we have few large groups left, then it's best to find `item` by taking a minimum. @@ -795,7 +726,7 @@ def finish_now_key(x): init_stack_pop = init_stack.pop is_init_sorted = True - if skip_root_node and item in root_nodes: + if item in root_nodes: item = init_stack_pop() while item in result: @@ -819,8 +750,10 @@ def graph_metrics(dependencies, dependents, total_dependencies): For each key we return: 1. **total_dependents**: The number of keys that can only be run - after this key is run. The root nodes have value 1 while deep child - nodes will have larger values. + after this key is run. + Note that this is only exact for trees. (undirected) cycles will cause + double counting of nodes. Therefore, this metric is an upper bound + approximation. 1 | @@ -1076,3 +1009,40 @@ def diagnostics(dsk, o=None, dependencies=None): for key, val in o.items() } return rv, pressure + + +def _f(): + ... 
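+# `_f` is a no-op placeholder; `_convert_task` below substitutes it for real callables so that `sanitize_dsk` preserves graph topology while dropping payload data.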
+ + +def _convert_task(task): + if istask(task): + assert callable(task[0]) + new_spec = [] + for el in task[1:]: + if isinstance(el, (str, int)): + new_spec.append(el) + elif isinstance(el, tuple): + if istask(el): + new_spec.append(_convert_task(el)) + else: + new_spec.append(el) + elif isinstance(el, list): + new_spec.append([_convert_task(e) for e in el]) + return (_f, *new_spec) + else: + return task + + +def sanitize_dsk(dsk): + """Take a dask graph and replace callables with a dummy function and remove + payload data like numpy arrays, dataframes, etc. + """ + new = {} + for key, values in dsk.items(): + new_key = key + new[new_key] = _convert_task(values) + if get_deps(new) != get_deps(dsk): + # The switch statement in _convert likely dropped some keys + raise RuntimeError("Sanitization failed to preserve topology.") + return new diff --git a/dask/tests/test_order.py b/dask/tests/test_order.py index cd8bca7889b..de43cc14a37 100644 --- a/dask/tests/test_order.py +++ b/dask/tests/test_order.py @@ -3,6 +3,7 @@ import pytest import dask +from dask.base import collections_to_dsk from dask.core import get_deps from dask.order import diagnostics, ndependencies, order from dask.utils_test import add, inc @@ -109,22 +110,19 @@ def test_base_of_reduce_preferred(abcde): assert o[(b, 1)] <= 3 -@pytest.mark.xfail(reason="Can't please 'em all") def test_avoid_upwards_branching(abcde): r""" - a1 - | - a2 - | - a3 d1 - / \ / - b1 c1 - | | - b2 c2 - | - c3 - - Prefer b1 over c1 because it won't stick around waiting for d1 to complete + a1 + | + a2 + | + a3 d1 + / \ / + b1 c1 + | | + b2 c2 + | + c3 """ a, b, c, d, e = abcde dsk = { @@ -135,11 +133,13 @@ def test_avoid_upwards_branching(abcde): (c, 1): (f, (c, 2)), (c, 2): (f, (c, 3)), (d, 1): (f, (c, 1)), + (c, 3): 1, + (b, 2): 1, } o = order(dsk) - assert o[(b, 1)] < o[(c, 1)] + assert o[(d, 1)] < o[(b, 1)] def test_avoid_upwards_branching_complex(abcde): @@ -272,27 +272,24 @@ def test_type_comparisions_ok(abcde): order(dsk) # this doesn't err -def test_prefer_short_dependents(abcde): +def test_favor_longest_critical_path(abcde): r""" - a - | - d b e - \ | / - c + a + | + d b e + \ | / + c - Prefer to finish d and e before starting b. That way c can be released - during the long computations. 
""" a, b, c, d, e = abcde dsk = {c: (f,), d: (f, c), e: (f, c), b: (f, c), a: (f, b)} o = order(dsk) - assert o[d] < o[b] - assert o[e] < o[b] + assert o[d] > o[b] + assert o[e] > o[b] -@pytest.mark.xfail(reason="This is challenging to do precisely") def test_run_smaller_sections(abcde): r""" aa @@ -301,36 +298,26 @@ def test_run_smaller_sections(abcde): / \ /| | / a c e cc - Prefer to run acb first because then we can get that out of the way """ a, b, c, d, e = abcde aa, bb, cc, dd = (x * 2 for x in [a, b, c, d]) - expected = [a, c, b, e, d, cc, bb, aa, dd] - - log = [] - - def f(x): - def _(*args): - log.append(x) - - return _ - dsk = { - a: (f(a),), - c: (f(c),), - e: (f(e),), - cc: (f(cc),), - b: (f(b), a, c), - d: (f(d), c, e), - bb: (f(bb), cc), - aa: (f(aa), d, bb), - dd: (f(dd), cc), + a: (f,), + c: (f,), + e: (f,), + cc: (f,), + b: (f, a, c), + d: (f, c, e), + bb: (f, cc), + aa: (f, d, bb), + dd: (f, cc), } - - dask.get(dsk, [aa, b, dd]) # trigger computation - - assert log == expected + o = order(dsk) + assert max(diagnostics(dsk)[1]) <= 4 # optimum is 3 + # This is a mildly ambiguous example + # https://github.com/dask/dask/pull/10535/files#r1337528255 + assert (o[aa] < o[a] and o[dd] < o[a]) or (o[b] < o[e] and o[b] < o[cc]) def test_local_parents_of_reduction(abcde): @@ -591,12 +578,18 @@ def test_dont_run_all_dependents_too_early(abcde): """From https://github.com/dask/dask-ml/issues/206#issuecomment-395873372""" a, b, c, d, e = abcde depth = 10 - dsk = {(a, 0): 0, (b, 0): 1, (c, 0): 2, (d, 0): (f, (a, 0), (b, 0), (c, 0))} + dsk = { + (a, 0): (f, 0), + (b, 0): (f, 1), + (c, 0): (f, 2), + (d, 0): (f, (a, 0), (b, 0), (c, 0)), + } for i in range(1, depth): dsk[(b, i)] = (f, (b, 0)) dsk[(c, i)] = (f, (c, 0)) dsk[(d, i)] = (f, (d, i - 1), (b, i), (c, i)) o = order(dsk) + expected = [3, 6, 9, 12, 15, 18, 21, 24, 27, 30] actual = sorted(v for (letter, num), v in o.items() if letter == d) assert expected == actual @@ -723,13 +716,17 @@ def test_order_with_equal_dependents(abcde): This DAG has enough structure to exercise more parts of `order` """ + # Lower pressure is better but this is where we are right now. Important is + # that no variation below should be worse since all variations below should + # reduce to the same graph when optimized/fused. 
+ max_pressure = 11 a, b, c, d, e = abcde dsk = {} abc = [a, b, c, d] for x in abc: dsk.update( { - (x, 0): 0, + (x, 0): (f, 0), (x, 1): (f, (x, 0)), (x, 2, 0): (f, (x, 0)), (x, 2, 1): (f, (x, 1)), @@ -753,7 +750,9 @@ def test_order_with_equal_dependents(abcde): val = o[(x, 6, i, 1)] - o[(x, 6, i, 0)] assert val > 0 # ideally, val == 2 total += val - assert total <= 110 # ideally, this should be 2 * 16 = 32 + assert total <= 56 # ideally, this should be 2 * 16 == 32 + pressure = diagnostics(dsk, o=o)[1] + assert max(pressure) <= max_pressure # Add one to the end of the nine bundles dsk2 = dict(dsk) @@ -764,10 +763,12 @@ def test_order_with_equal_dependents(abcde): total = 0 for x in abc: for i in range(len(abc)): - val = o[(x, 7, i, 0)] - o[(x, 6, i, 1)] + val = o[(x, 6, i, 1)] - o[(x, 7, i, 0)] assert val > 0 # ideally, val == 3 total += val - assert total <= 138 # ideally, this should be 3 * 16 == 48 + assert total <= 75 # ideally, this should be 3 * 16 == 48 + pressure = diagnostics(dsk2, o=o)[1] + assert max(pressure) <= max_pressure # Remove one from each of the nine bundles dsk3 = dict(dsk) @@ -778,10 +779,12 @@ def test_order_with_equal_dependents(abcde): total = 0 for x in abc: for i in range(len(abc)): - val = o[(x, 6, i, 0)] - o[(x, 5, i, 1)] - assert val > 0 # ideally, val == 2 + val = o[(x, 5, i, 1)] - o[(x, 6, i, 0)] + assert val > 0 total += val - assert total <= 98 # ideally, this should be 2 * 16 == 32 + assert total <= 45 # ideally, this should be 2 * 16 == 32 + pressure = diagnostics(dsk3, o=o)[1] + assert max(pressure) <= max_pressure # Remove another one from each of the nine bundles dsk4 = dict(dsk3) @@ -789,10 +792,11 @@ def test_order_with_equal_dependents(abcde): for i in range(len(abc)): del dsk4[(x, 6, i, 0)] o = order(dsk4) - total = 0 + pressure = diagnostics(dsk4, o=o)[1] + assert max(pressure) <= max_pressure for x in abc: for i in range(len(abc)): - assert o[(x, 5, i, 1)] - o[(x, 5, i, 0)] == 1 + assert abs(o[(x, 5, i, 1)] - o[(x, 5, i, 0)]) <= 10 def test_terminal_node_backtrack(): @@ -986,7 +990,7 @@ def cost(deps): def test_diagnostics(abcde): r""" - a1 b1 c2 d1 e1 + a1 b1 c1 d1 e1 /|\ /|\ /|\ /| / / | X | X | X | / / |/ \|/ \|/ \|/ @@ -1005,6 +1009,10 @@ def test_diagnostics(abcde): (d, 1): (f, (d, 0), (e, 0)), (e, 1): (f, (e, 0)), } + o = order(dsk) + assert o[(e, 1)] == len(dsk) - 1 + assert o[(d, 1)] == len(dsk) - 2 + assert o[(c, 1)] == len(dsk) - 3 info, memory_over_time = diagnostics(dsk) assert memory_over_time == [0, 1, 2, 3, 2, 3, 2, 3, 2, 1] assert {key: val.order for key, val in info.items()} == { @@ -1067,3 +1075,596 @@ def test_diagnostics(abcde): (d, 1): 2, (e, 1): 1, } + + +def test_xarray_like_reduction(): + a, b, c, d, e = list("abcde") + + dsk = {} + for ix in range(3): + part = { + # Part1 + (a, 0, ix): (f,), + (a, 1, ix): (f,), + (b, 0, ix): (f, (a, 0, ix)), + (b, 1, ix): (f, (a, 0, ix), (a, 1, ix)), + (b, 2, ix): (f, (a, 1, ix)), + (c, 0, ix): (f, (b, 0, ix)), + (c, 1, ix): (f, (b, 1, ix)), + (c, 2, ix): (f, (b, 2, ix)), + } + dsk.update(part) + for ix in range(3): + dsk.update( + { + (d, ix): (f, (c, ix, 0), (c, ix, 1), (c, ix, 2)), + } + ) + o = order(dsk) + _, pressure = diagnostics(dsk, o=o) + assert max(pressure) <= 9 + + +@pytest.mark.parametrize( + "optimize", + [True, False], +) +def test_array_vs_dataframe(optimize): + xr = pytest.importorskip("xarray") + + import dask.array as da + + size = 5000 + ds = xr.Dataset( + dict( + anom_u=( + ["time", "face", "j", "i"], + da.random.random((size, 1, 987, 1920), chunks=(10, 1, -1, 
-1)), + ), + anom_v=( + ["time", "face", "j", "i"], + da.random.random((size, 1, 987, 1920), chunks=(10, 1, -1, -1)), + ), + ) + ) + + quad = ds**2 + quad["uv"] = ds.anom_u * ds.anom_v + mean = quad.mean("time") + diag_array = diagnostics(collections_to_dsk([mean], optimize_graph=optimize)) + diag_df = diagnostics( + collections_to_dsk([mean.to_dask_dataframe()], optimize_graph=optimize) + ) + assert max(diag_df[1]) == max(diag_array[1]) + assert max(diag_array[1]) < 50 + + +def test_anom_mean(): + np = pytest.importorskip("numpy") + xr = pytest.importorskip("xarray") + + import dask.array as da + from dask.utils import parse_bytes + + data = da.random.random( + (260, 1310720), + chunks=(1, parse_bytes("10MiB") // 8), + ) + + ngroups = data.shape[0] // 4 + arr = xr.DataArray( + data, + dims=["time", "x"], + coords={"day": ("time", np.arange(data.shape[0]) % ngroups)}, + ) + data = da.random.random((5, 1), chunks=(1, 1)) + + arr = xr.DataArray( + data, + dims=["time", "x"], + coords={"day": ("time", np.arange(5) % 2)}, + ) + + clim = arr.groupby("day").mean(dim="time") + anom = arr.groupby("day") - clim + anom_mean = anom.mean(dim="time") + _, pressure = diagnostics(anom_mean.__dask_graph__()) + assert max(pressure) <= 9 + + +def test_anom_mean_raw(): + dsk = { + ("d", 0, 0): (f, ("a", 0, 0), ("b1", 0, 0)), + ("d", 1, 0): (f, ("a", 1, 0), ("b1", 1, 0)), + ("d", 2, 0): (f, ("a", 2, 0), ("b1", 2, 0)), + ("d", 3, 0): (f, ("a", 3, 0), ("b1", 3, 0)), + ("d", 4, 0): (f, ("a", 4, 0), ("b1", 4, 0)), + ("a", 0, 0): (f, f, "random_sample", None, (1, 1), [], {}), + ("a", 1, 0): (f, f, "random_sample", None, (1, 1), [], {}), + ("a", 2, 0): (f, f, "random_sample", None, (1, 1), [], {}), + ("a", 3, 0): (f, f, "random_sample", None, (1, 1), [], {}), + ("a", 4, 0): (f, f, "random_sample", None, (1, 1), [], {}), + ("e", 0, 0): (f, ("g1", 0)), + ("e", 1, 0): (f, ("g3", 0)), + ("b0", 0, 0): (f, ("a", 0, 0)), + ("b0", 1, 0): (f, ("a", 2, 0)), + ("b0", 2, 0): (f, ("a", 4, 0)), + ("c0", 0, 0): (f, ("b0", 0, 0)), + ("c0", 1, 0): (f, ("b0", 1, 0)), + ("c0", 2, 0): (f, ("b0", 2, 0)), + ("g1", 0): (f, [("c0", 0, 0), ("c0", 1, 0), ("c0", 2, 0)]), + ("b2", 0, 0): (f, ("a", 1, 0)), + ("b2", 1, 0): (f, ("a", 3, 0)), + ("c1", 0, 0): (f, ("b2", 0, 0)), + ("c1", 1, 0): (f, ("b2", 1, 0)), + ("g3", 0): (f, [("c1", 0, 0), ("c1", 1, 0)]), + ("b1", 0, 0): (f, ("e", 0, 0)), + ("b1", 1, 0): (f, ("e", 1, 0)), + ("b1", 2, 0): (f, ("e", 0, 0)), + ("b1", 3, 0): (f, ("e", 1, 0)), + ("b1", 4, 0): (f, ("e", 0, 0)), + ("c2", 0, 0): (f, ("d", 0, 0)), + ("c2", 1, 0): (f, ("d", 1, 0)), + ("c2", 2, 0): (f, ("d", 2, 0)), + ("c2", 3, 0): (f, ("d", 3, 0)), + ("c2", 4, 0): (f, ("d", 4, 0)), + ("f", 0, 0): (f, [("c2", 0, 0), ("c2", 1, 0), ("c2", 2, 0), ("c2", 3, 0)]), + ("f", 1, 0): (f, [("c2", 4, 0)]), + ("g2", 0): (f, [("f", 0, 0), ("f", 1, 0)]), + } + + o = order(dsk) + + # The left hand computation branch should complete before we start loading + # more data + nodes_to_finish_before_loading_more_data = [ + ("f", 1, 0), + ("d", 0, 0), + ("d", 2, 0), + ("d", 4, 0), + ] + for n in nodes_to_finish_before_loading_more_data: + assert o[n] < o[("a", 1, 0)] + assert o[n] < o[("a", 3, 0)] + + +def test_flaky_array_reduction(): + first = { + ("mean_agg-aggregate-10d721567ef5a0d6a0e1afae8a87c066", 0, 0, 0): ( + f, + [ + ("mean_combine-partial-17c7b5c6eed42e203858b3f6dde16003", 0, 0, 0, 0), + ("mean_combine-partial-17c7b5c6eed42e203858b3f6dde16003", 1, 0, 0, 0), + ], + ), + ("mean_combine-partial-17c7b5c6eed42e203858b3f6dde16003", 0, 0, 0, 0): ( + f, + [ 
+ ("mean_chunk-98a32cd9f4fadbed908fffb32e0c9679", 0, 0, 0, 0), + ("mean_chunk-98a32cd9f4fadbed908fffb32e0c9679", 1, 0, 0, 0), + ("mean_chunk-98a32cd9f4fadbed908fffb32e0c9679", 2, 0, 0, 0), + ("mean_chunk-98a32cd9f4fadbed908fffb32e0c9679", 3, 0, 0, 0), + ], + ), + ("mean_combine-partial-17c7b5c6eed42e203858b3f6dde16003", 1, 0, 0, 0): ( + "mean_chunk-mean_combine-partial-17c7b5c6eed42e203858b3f6dde16003", + 1, + 0, + 0, + 0, + ), + ("mean_chunk-98a32cd9f4fadbed908fffb32e0c9679", 0, 0, 0, 0): ( + f, + ("random_sample-e16bcfb15a013023c98a21e2f03d66a9", 0, 0, 0, 0), + ("random_sample-02eaa4a8dbb23fac4db22ad034c401b3", 0, 0, 0, 0), + ), + ("mean_chunk-98a32cd9f4fadbed908fffb32e0c9679", 1, 0, 0, 0): ( + f, + ("random_sample-02eaa4a8dbb23fac4db22ad034c401b3", 1, 0, 0, 0), + ("random_sample-e16bcfb15a013023c98a21e2f03d66a9", 1, 0, 0, 0), + ), + ("mean_chunk-98a32cd9f4fadbed908fffb32e0c9679", 2, 0, 0, 0): ( + f, + ("random_sample-e16bcfb15a013023c98a21e2f03d66a9", 2, 0, 0, 0), + ("random_sample-02eaa4a8dbb23fac4db22ad034c401b3", 2, 0, 0, 0), + ), + ("mean_chunk-98a32cd9f4fadbed908fffb32e0c9679", 3, 0, 0, 0): ( + f, + ("random_sample-e16bcfb15a013023c98a21e2f03d66a9", 3, 0, 0, 0), + ("random_sample-02eaa4a8dbb23fac4db22ad034c401b3", 3, 0, 0, 0), + ), + ("mean_agg-aggregate-fdb340546b01334890192fcfa55fa0d9", 0, 0, 0): ( + f, + [ + ("mean_combine-partial-23adb4747560e6e33afd63c5bb179709", 0, 0, 0, 0), + ("mean_combine-partial-23adb4747560e6e33afd63c5bb179709", 1, 0, 0, 0), + ], + ), + ("mean_combine-partial-23adb4747560e6e33afd63c5bb179709", 0, 0, 0, 0): ( + f, + [ + ("mean_chunk-7edba1c5a284fcec88b9efdda6c2135f", 0, 0, 0, 0), + ("mean_chunk-7edba1c5a284fcec88b9efdda6c2135f", 1, 0, 0, 0), + ("mean_chunk-7edba1c5a284fcec88b9efdda6c2135f", 2, 0, 0, 0), + ("mean_chunk-7edba1c5a284fcec88b9efdda6c2135f", 3, 0, 0, 0), + ], + ), + ("mean_combine-partial-23adb4747560e6e33afd63c5bb179709", 1, 0, 0, 0): ( + "mean_chunk-mean_combine-partial-23adb4747560e6e33afd63c5bb179709", + 1, + 0, + 0, + 0, + ), + ("mean_chunk-7edba1c5a284fcec88b9efdda6c2135f", 0, 0, 0, 0): ( + f, + ("random_sample-02eaa4a8dbb23fac4db22ad034c401b3", 0, 0, 0, 0), + 2, + ), + ("mean_chunk-7edba1c5a284fcec88b9efdda6c2135f", 1, 0, 0, 0): ( + f, + ("random_sample-02eaa4a8dbb23fac4db22ad034c401b3", 1, 0, 0, 0), + 2, + ), + ("mean_chunk-7edba1c5a284fcec88b9efdda6c2135f", 2, 0, 0, 0): ( + f, + ("random_sample-02eaa4a8dbb23fac4db22ad034c401b3", 2, 0, 0, 0), + 2, + ), + ("mean_chunk-7edba1c5a284fcec88b9efdda6c2135f", 3, 0, 0, 0): ( + f, + ("random_sample-02eaa4a8dbb23fac4db22ad034c401b3", 3, 0, 0, 0), + 2, + ), + ("random_sample-02eaa4a8dbb23fac4db22ad034c401b3", 0, 0, 0, 0): (f, 1), + ("random_sample-02eaa4a8dbb23fac4db22ad034c401b3", 1, 0, 0, 0): (f, 1), + ("random_sample-02eaa4a8dbb23fac4db22ad034c401b3", 2, 0, 0, 0): (f, 1), + ("random_sample-02eaa4a8dbb23fac4db22ad034c401b3", 3, 0, 0, 0): (f, 1), + ("random_sample-02eaa4a8dbb23fac4db22ad034c401b3", 4, 0, 0, 0): (f, 1), + ("mean_agg-aggregate-cc19342c8116d616fc6573f5d20b5762", 0, 0, 0): ( + f, + [ + ("mean_combine-partial-0c98c5a4517f58f8268985e7464daace", 0, 0, 0, 0), + ("mean_combine-partial-0c98c5a4517f58f8268985e7464daace", 1, 0, 0, 0), + ], + ), + ("mean_combine-partial-0c98c5a4517f58f8268985e7464daace", 0, 0, 0, 0): ( + f, + [ + ("mean_chunk-540e88b7d9289f6b5461b95a0787af3e", 0, 0, 0, 0), + ("mean_chunk-540e88b7d9289f6b5461b95a0787af3e", 1, 0, 0, 0), + ("mean_chunk-540e88b7d9289f6b5461b95a0787af3e", 2, 0, 0, 0), + ("mean_chunk-540e88b7d9289f6b5461b95a0787af3e", 3, 0, 0, 0), + ], + ), + 
("mean_combine-partial-0c98c5a4517f58f8268985e7464daace", 1, 0, 0, 0): ( + "mean_chunk-mean_combine-partial-0c98c5a4517f58f8268985e7464daace", + 1, + 0, + 0, + 0, + ), + ("mean_chunk-540e88b7d9289f6b5461b95a0787af3e", 0, 0, 0, 0): ( + f, + ("random_sample-e16bcfb15a013023c98a21e2f03d66a9", 0, 0, 0, 0), + ), + ("mean_chunk-540e88b7d9289f6b5461b95a0787af3e", 1, 0, 0, 0): ( + f, + ("random_sample-e16bcfb15a013023c98a21e2f03d66a9", 1, 0, 0, 0), + ), + ("mean_chunk-540e88b7d9289f6b5461b95a0787af3e", 2, 0, 0, 0): ( + f, + ("random_sample-e16bcfb15a013023c98a21e2f03d66a9", 2, 0, 0, 0), + ), + ("mean_chunk-540e88b7d9289f6b5461b95a0787af3e", 3, 0, 0, 0): ( + f, + ("random_sample-e16bcfb15a013023c98a21e2f03d66a9", 3, 0, 0, 0), + ), + ("random_sample-e16bcfb15a013023c98a21e2f03d66a9", 0, 0, 0, 0): (f, 1), + ("random_sample-e16bcfb15a013023c98a21e2f03d66a9", 1, 0, 0, 0): (f, 1), + ("random_sample-e16bcfb15a013023c98a21e2f03d66a9", 2, 0, 0, 0): (f, 1), + ("random_sample-e16bcfb15a013023c98a21e2f03d66a9", 3, 0, 0, 0): (f, 1), + ("random_sample-e16bcfb15a013023c98a21e2f03d66a9", 4, 0, 0, 0): (f, 1), + ( + "mean_chunk-mean_combine-partial-17c7b5c6eed42e203858b3f6dde16003", + 1, + 0, + 0, + 0, + ): ( + f, + [ + ( + f, + ("random_sample-e16bcfb15a013023c98a21e2f03d66a9", 4, 0, 0, 0), + ("random_sample-02eaa4a8dbb23fac4db22ad034c401b3", 4, 0, 0, 0), + ) + ], + ), + ( + "mean_chunk-mean_combine-partial-0c98c5a4517f58f8268985e7464daace", + 1, + 0, + 0, + 0, + ): ( + f, + [(f, ("random_sample-e16bcfb15a013023c98a21e2f03d66a9", 4, 0, 0, 0), 2)], + ), + ( + "mean_chunk-mean_combine-partial-23adb4747560e6e33afd63c5bb179709", + 1, + 0, + 0, + 0, + ): ( + f, + [(f, ("random_sample-02eaa4a8dbb23fac4db22ad034c401b3", 4, 0, 0, 0), 2)], + ), + } + + other = { + ("mean_agg-aggregate-e79dd3b9757c9fb2ad7ade96f3f6c814", 0, 0, 0): ( + f, + [ + ("mean_combine-partial-e7d9fd7c132e12007a4b4f62ce443a75", 0, 0, 0, 0), + ("mean_combine-partial-e7d9fd7c132e12007a4b4f62ce443a75", 1, 0, 0, 0), + ], + ), + ("mean_combine-partial-e7d9fd7c132e12007a4b4f62ce443a75", 0, 0, 0, 0): ( + f, + [ + ("mean_chunk-0df65d9a6e168673f32082f59f19576a", 0, 0, 0, 0), + ("mean_chunk-0df65d9a6e168673f32082f59f19576a", 1, 0, 0, 0), + ("mean_chunk-0df65d9a6e168673f32082f59f19576a", 2, 0, 0, 0), + ("mean_chunk-0df65d9a6e168673f32082f59f19576a", 3, 0, 0, 0), + ], + ), + ("mean_combine-partial-e7d9fd7c132e12007a4b4f62ce443a75", 1, 0, 0, 0): ( + "mean_chunk-mean_combine-partial-e7d9fd7c132e12007a4b4f62ce443a75", + 1, + 0, + 0, + 0, + ), + ("mean_chunk-0df65d9a6e168673f32082f59f19576a", 0, 0, 0, 0): ( + f, + ("random_sample-a155d5a37ac5e09ede89c98a3bfcadff", 0, 0, 0, 0), + ("random_sample-241fdbadc062900adc59d1a79c4c41e1", 0, 0, 0, 0), + ), + ("mean_chunk-0df65d9a6e168673f32082f59f19576a", 1, 0, 0, 0): ( + f, + ("random_sample-a155d5a37ac5e09ede89c98a3bfcadff", 1, 0, 0, 0), + ("random_sample-241fdbadc062900adc59d1a79c4c41e1", 1, 0, 0, 0), + ), + ("mean_chunk-0df65d9a6e168673f32082f59f19576a", 2, 0, 0, 0): ( + f, + ("random_sample-a155d5a37ac5e09ede89c98a3bfcadff", 2, 0, 0, 0), + ("random_sample-241fdbadc062900adc59d1a79c4c41e1", 2, 0, 0, 0), + ), + ("mean_chunk-0df65d9a6e168673f32082f59f19576a", 3, 0, 0, 0): ( + f, + ("random_sample-a155d5a37ac5e09ede89c98a3bfcadff", 3, 0, 0, 0), + ("random_sample-241fdbadc062900adc59d1a79c4c41e1", 3, 0, 0, 0), + ), + ("mean_agg-aggregate-c7647920facf0e557f947b7a6626b7be", 0, 0, 0): ( + f, + [ + ("mean_combine-partial-57413f0bb18da78db0f689a096c7fbbf", 0, 0, 0, 0), + ("mean_combine-partial-57413f0bb18da78db0f689a096c7fbbf", 1, 
0, 0, 0), + ], + ), + ("mean_combine-partial-57413f0bb18da78db0f689a096c7fbbf", 0, 0, 0, 0): ( + f, + [ + ("mean_chunk-d6bd425ea61739f1eaa71762fe3bbbd7", 0, 0, 0, 0), + ("mean_chunk-d6bd425ea61739f1eaa71762fe3bbbd7", 1, 0, 0, 0), + ("mean_chunk-d6bd425ea61739f1eaa71762fe3bbbd7", 2, 0, 0, 0), + ("mean_chunk-d6bd425ea61739f1eaa71762fe3bbbd7", 3, 0, 0, 0), + ], + ), + ("mean_combine-partial-57413f0bb18da78db0f689a096c7fbbf", 1, 0, 0, 0): ( + "mean_chunk-mean_combine-partial-57413f0bb18da78db0f689a096c7fbbf", + 1, + 0, + 0, + 0, + ), + ("mean_chunk-d6bd425ea61739f1eaa71762fe3bbbd7", 0, 0, 0, 0): ( + f, + ("random_sample-241fdbadc062900adc59d1a79c4c41e1", 0, 0, 0, 0), + 2, + ), + ("mean_chunk-d6bd425ea61739f1eaa71762fe3bbbd7", 1, 0, 0, 0): ( + f, + ("random_sample-241fdbadc062900adc59d1a79c4c41e1", 1, 0, 0, 0), + 2, + ), + ("mean_chunk-d6bd425ea61739f1eaa71762fe3bbbd7", 2, 0, 0, 0): ( + f, + ("random_sample-241fdbadc062900adc59d1a79c4c41e1", 2, 0, 0, 0), + 2, + ), + ("mean_chunk-d6bd425ea61739f1eaa71762fe3bbbd7", 3, 0, 0, 0): ( + f, + ("random_sample-241fdbadc062900adc59d1a79c4c41e1", 3, 0, 0, 0), + 2, + ), + ("random_sample-241fdbadc062900adc59d1a79c4c41e1", 0, 0, 0, 0): ( + f, + "random_sample", + (10, 1, 987, 1920), + [], + ), + ("random_sample-241fdbadc062900adc59d1a79c4c41e1", 1, 0, 0, 0): ( + f, + "random_sample", + (10, 1, 987, 1920), + [], + ), + ("random_sample-241fdbadc062900adc59d1a79c4c41e1", 2, 0, 0, 0): ( + f, + "random_sample", + (10, 1, 987, 1920), + [], + ), + ("random_sample-241fdbadc062900adc59d1a79c4c41e1", 3, 0, 0, 0): ( + f, + "random_sample", + (10, 1, 987, 1920), + [], + ), + ("random_sample-241fdbadc062900adc59d1a79c4c41e1", 4, 0, 0, 0): ( + f, + "random_sample", + (10, 1, 987, 1920), + [], + ), + ("mean_agg-aggregate-05071ebaabb68a64c180f6f443c5c8f4", 0, 0, 0): ( + f, + [ + ("mean_combine-partial-a7c475f79a46af4265b189ffdc000bb3", 0, 0, 0, 0), + ("mean_combine-partial-a7c475f79a46af4265b189ffdc000bb3", 1, 0, 0, 0), + ], + ), + ("mean_combine-partial-a7c475f79a46af4265b189ffdc000bb3", 0, 0, 0, 0): ( + f, + [ + ("mean_chunk-fd17feaf0728ea7a89d119d3fd172c75", 0, 0, 0, 0), + ("mean_chunk-fd17feaf0728ea7a89d119d3fd172c75", 1, 0, 0, 0), + ("mean_chunk-fd17feaf0728ea7a89d119d3fd172c75", 2, 0, 0, 0), + ("mean_chunk-fd17feaf0728ea7a89d119d3fd172c75", 3, 0, 0, 0), + ], + ), + ("mean_combine-partial-a7c475f79a46af4265b189ffdc000bb3", 1, 0, 0, 0): ( + "mean_chunk-mean_combine-partial-a7c475f79a46af4265b189ffdc000bb3", + 1, + 0, + 0, + 0, + ), + ("mean_chunk-fd17feaf0728ea7a89d119d3fd172c75", 0, 0, 0, 0): ( + f, + ("random_sample-a155d5a37ac5e09ede89c98a3bfcadff", 0, 0, 0, 0), + 2, + ), + ("mean_chunk-fd17feaf0728ea7a89d119d3fd172c75", 1, 0, 0, 0): ( + f, + ("random_sample-a155d5a37ac5e09ede89c98a3bfcadff", 1, 0, 0, 0), + 2, + ), + ("mean_chunk-fd17feaf0728ea7a89d119d3fd172c75", 2, 0, 0, 0): ( + f, + ("random_sample-a155d5a37ac5e09ede89c98a3bfcadff", 2, 0, 0, 0), + 2, + ), + ("mean_chunk-fd17feaf0728ea7a89d119d3fd172c75", 3, 0, 0, 0): ( + f, + ("random_sample-a155d5a37ac5e09ede89c98a3bfcadff", 3, 0, 0, 0), + 2, + ), + ("random_sample-a155d5a37ac5e09ede89c98a3bfcadff", 0, 0, 0, 0): ( + f, + "random_sample", + (10, 1, 987, 1920), + [], + ), + ("random_sample-a155d5a37ac5e09ede89c98a3bfcadff", 1, 0, 0, 0): ( + f, + "random_sample", + (10, 1, 987, 1920), + [], + ), + ("random_sample-a155d5a37ac5e09ede89c98a3bfcadff", 2, 0, 0, 0): ( + f, + "random_sample", + (10, 1, 987, 1920), + [], + ), + ("random_sample-a155d5a37ac5e09ede89c98a3bfcadff", 3, 0, 0, 0): ( + f, + "random_sample", + (10, 
1, 987, 1920), + [], + ), + ("random_sample-a155d5a37ac5e09ede89c98a3bfcadff", 4, 0, 0, 0): ( + f, + "random_sample", + (10, 1, 987, 1920), + [], + ), + ( + "mean_chunk-mean_combine-partial-a7c475f79a46af4265b189ffdc000bb3", + 1, + 0, + 0, + 0, + ): ( + f, + [(f, ("random_sample-a155d5a37ac5e09ede89c98a3bfcadff", 4, 0, 0, 0), 2)], + ), + ( + "mean_chunk-mean_combine-partial-e7d9fd7c132e12007a4b4f62ce443a75", + 1, + 0, + 0, + 0, + ): ( + f, + [ + ( + f, + ("random_sample-a155d5a37ac5e09ede89c98a3bfcadff", 4, 0, 0, 0), + ("random_sample-241fdbadc062900adc59d1a79c4c41e1", 4, 0, 0, 0), + ) + ], + ), + ( + "mean_chunk-mean_combine-partial-57413f0bb18da78db0f689a096c7fbbf", + 1, + 0, + 0, + 0, + ): ( + f, + [(f, ("random_sample-241fdbadc062900adc59d1a79c4c41e1", 4, 0, 0, 0), 2)], + ), + } + first_pressure = max(diagnostics(first)[1]) + second_pressure = max(diagnostics(other)[1]) + assert first_pressure == second_pressure + + +def test_flox_reduction(): + # TODO: It would be nice to scramble keys to ensure we're not comparing keys + dsk = { + "A0": (f, 1), + ("A1", 0): (f, 1), + ("A1", 1): (f, 1), + ("A2", 0): (f, 1), + ("A2", 1): (f, 1), + ("B1", 0): (f, [(f, ("A2", 0))]), + ("B1", 1): (f, [(f, ("A2", 1))]), + ("B2", 1): (f, [(f, ("A1", 1))]), + ("B2", 0): (f, [(f, ("A1", 0))]), + ("B11", 0): ("B1", 0), + ("B11", 1): ("B1", 1), + ("B22", 0): ("B2", 0), + ("B22", 1): ("B2", 1), + ("C1", 0): (f, ("B22", 0)), + ("C1", 1): (f, ("B22", 1)), + ("C2", 0): (f, ("B11", 0)), + ("C2", 1): (f, ("B11", 1)), + ("E", 1): (f, [(f, ("A1", 1), ("A2", 1), ("C1", 1), ("C2", 1))]), + ("E", 0): (f, [(f, ("A1", 0), ("A2", 0), ("C1", 0), ("C2", 0))]), + ("EE", 0): ("E", 0), + ("EE", 1): ("E", 1), + ("F1", 0): (f, "A0", ("B11", 0)), + ("F1", 1): (f, "A0", ("B22", 0)), + ("F1", 2): (f, "A0", ("EE", 0)), + ("F2", 0): (f, "A0", ("B11", 1)), + ("F2", 1): (f, "A0", ("B22", 1)), + ("F2", 2): (f, "A0", ("EE", 1)), + } + o = order(dsk) + assert max(o[("F1", ix)] for ix in range(3)) < min(o[("F2", ix)] for ix in range(3)) From 9f7d55705f0eb74a172f5fdc3690aa65fbe21762 Mon Sep 17 00:00:00 2001 From: Florian Jetter Date: Thu, 12 Oct 2023 10:11:07 +0200 Subject: [PATCH 6/6] Add typing to dask.order (#10553) --- dask/core.py | 45 ++++++++-- dask/graph_manipulation.py | 6 +- dask/highlevelgraph.py | 16 ++-- dask/order.py | 172 ++++++++++++++++++++++--------------- dask/typing.py | 2 + pyproject.toml | 8 ++ 6 files changed, 162 insertions(+), 87 deletions(-) diff --git a/dask/core.py b/dask/core.py index 8c590875105..ef75ca563cd 100644 --- a/dask/core.py +++ b/dask/core.py @@ -1,10 +1,10 @@ from __future__ import annotations from collections import defaultdict -from collections.abc import Collection, Iterable -from typing import Any, cast +from collections.abc import Collection, Iterable, Mapping +from typing import Any, Literal, TypeVar, cast, overload -from dask.typing import Key, no_default +from dask.typing import Graph, Key, NoDefault, no_default def ishashable(x): @@ -223,7 +223,32 @@ def validate_key(key: object) -> None: raise TypeError(f"Unexpected key type {type(key)} (value: {key!r})") -def get_dependencies(dsk, key=None, task=no_default, as_list=False): +@overload +def get_dependencies( + dsk: Graph, + key: Key | None = ..., + task: Key | NoDefault = ..., + as_list: Literal[False] = ..., +) -> set[Key]: + ... + + +@overload +def get_dependencies( + dsk: Graph, + key: Key | None, + task: Key | NoDefault, + as_list: Literal[True], +) -> list[Key]: + ... 
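+# The overloads above only narrow the return type (set vs. list) based on `as_list`; the single implementation below handles both cases.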
+ + +def get_dependencies( + dsk: Graph, + key: Key | None = None, + task: Key | NoDefault = no_default, + as_list: bool = False, +) -> set[Key] | list[Key]: """Get the immediate tasks on which this task depends Examples @@ -264,7 +289,7 @@ def get_dependencies(dsk, key=None, task=no_default, as_list=False): return keys_in_tasks(dsk, [arg], as_list=as_list) -def get_deps(dsk): +def get_deps(dsk: Graph) -> tuple[dict[Key, set[Key]], dict[Key, set[Key]]]: """Get dependencies and dependents from dask dask graph >>> inc = lambda x: x + 1 @@ -308,7 +333,10 @@ def flatten(seq, container=list): yield item -def reverse_dict(d): +T_ = TypeVar("T_") + + +def reverse_dict(d: Mapping[T_, Iterable[T_]]) -> dict[T_, set[T_]]: """ >>> a, b, c = 'abc' @@ -316,14 +344,13 @@ def reverse_dict(d): >>> reverse_dict(d) # doctest: +SKIP {'a': set([]), 'b': set(['a']}, 'c': set(['a', 'b'])} """ - result = defaultdict(set) + result: defaultdict[T_, set[T_]] = defaultdict(set) _add = set.add for k, vals in d.items(): result[k] for val in vals: _add(result[val], k) - result.default_factory = None - return result + return dict(result) def subs(task, key, val): diff --git a/dask/graph_manipulation.py b/dask/graph_manipulation.py index 3c4cbec8d2b..d712fdea98d 100644 --- a/dask/graph_manipulation.py +++ b/dask/graph_manipulation.py @@ -5,8 +5,8 @@ from __future__ import annotations import uuid -from collections.abc import Callable, Hashable, Set -from typing import Any, Literal, TypeVar +from collections.abc import Callable, Hashable +from typing import Literal, TypeVar from dask.base import ( clone_key, @@ -321,7 +321,7 @@ def _bind_one( dsk = child.__dask_graph__() # type: ignore new_layers: dict[str, Layer] = {} - new_deps: dict[str, Set[Any]] = {} + new_deps: dict[str, set[str]] = {} if isinstance(dsk, HighLevelGraph): try: diff --git a/dask/highlevelgraph.py b/dask/highlevelgraph.py index eb13ff013f7..0f13212ce03 100644 --- a/dask/highlevelgraph.py +++ b/dask/highlevelgraph.py @@ -96,7 +96,7 @@ def is_materialized(self) -> bool: return True @abc.abstractmethod - def get_output_keys(self) -> Set: + def get_output_keys(self) -> Set[Key]: """Return a set of all output keys Output keys are all keys in the layer that might be referenced by @@ -405,16 +405,16 @@ class HighLevelGraph(Graph): """ layers: Mapping[str, Layer] - dependencies: Mapping[str, Set[str]] - key_dependencies: dict[Key, Set[Key]] + dependencies: Mapping[str, set[str]] + key_dependencies: dict[Key, set[Key]] _to_dict: dict _all_external_keys: set def __init__( self, layers: Mapping[str, Graph], - dependencies: Mapping[str, Set[str]], - key_dependencies: dict[Key, Set[Key]] | None = None, + dependencies: Mapping[str, set[str]], + key_dependencies: dict[Key, set[Key]] | None = None, ): self.dependencies = dependencies self.key_dependencies = key_dependencies or {} @@ -487,7 +487,7 @@ def from_collections( return cls._from_collection(name, layer, dependencies[0]) layers = {name: layer} name_dep: set[str] = set() - deps: dict[str, Set[str]] = {name: name_dep} + deps: dict[str, set[str]] = {name: name_dep} for collection in toolz.unique(dependencies, key=id): if is_dask_collection(collection): graph = collection.__dask_graph__() @@ -583,7 +583,7 @@ def items(self) -> ItemsView[Key, Any]: def values(self) -> ValuesView[Any]: return self.to_dict().values() - def get_all_dependencies(self) -> dict[Key, Set[Key]]: + def get_all_dependencies(self) -> dict[Key, set[Key]]: """Get dependencies of all keys This will in most cases materialize all layers, which makes 
@@ -616,7 +616,7 @@ def copy(self) -> HighLevelGraph: @classmethod def merge(cls, *graphs: Graph) -> HighLevelGraph: layers: dict[str, Graph] = {} - dependencies: dict[str, Set[str]] = {} + dependencies: dict[str, set[str]] = {} for g in graphs: if isinstance(g, HighLevelGraph): layers.update(g.layers) diff --git a/dask/order.py b/dask/order.py index 457841e49ec..9a867866027 100644 --- a/dask/order.py +++ b/dask/order.py @@ -78,12 +78,18 @@ good proxy for ordering. This is usually a good idea and a sane default. """ from collections import defaultdict, namedtuple +from collections.abc import Mapping, MutableMapping from math import log +from typing import Any, cast from dask.core import get_dependencies, get_deps, getcycle, istask, reverse_dict +from dask.typing import Key -def order(dsk, dependencies=None): +def order( + dsk: MutableMapping[Key, Any], + dependencies: MutableMapping[Key, set[Key]] | None = None, +) -> dict[Key, int]: """Order nodes in dask graph This produces an ordering over our tasks that we use to break ties when @@ -110,11 +116,11 @@ def order(dsk, dependencies=None): """ if not dsk: return {} + dsk = dict(dsk) if dependencies is None: dependencies = {k: get_dependencies(dsk, k) for k in dsk} - dependents = reverse_dict(dependencies) num_needed, total_dependencies = ndependencies(dependencies, dependents) metrics = graph_metrics(dependencies, dependents, total_dependencies) @@ -137,9 +143,9 @@ def order(dsk, dependencies=None): if len(root_nodes) > 1: # This is also nice because it makes us robust to difference when # computing vs persisting collections - root = object() + root = cast(Key, object()) - def _f(*args, **kwargs): + def _f(*args: Any, **kwargs: Any) -> None: pass dsk[root] = (_f, *root_nodes) @@ -148,6 +154,7 @@ def _f(*args, **kwargs): del o[root] return o + init_stack: dict[Key, tuple] | set[Key] | list[Key] # Leaf nodes. We choose one--the initial node--for each weakly connected subgraph. # Let's calculate the `initial_stack_key` as we determine `init_stack` set. init_stack = { @@ -179,12 +186,14 @@ def _f(*args, **kwargs): # This value is static, so we pre-compute as the value of this dict. initial_stack_key = init_stack.__getitem__ - def dependents_key(x): + def dependents_key(x: Key) -> tuple: """Choose a path from our starting task to our tactical goal This path is connected to a large goal, but focuses on completing a small goal and being memory efficient. """ + assert dependencies is not None + return ( # Focus on being memory-efficient len(dependents[x]) - len(dependencies[x]) + num_needed[x], @@ -196,11 +205,12 @@ def dependents_key(x): StrComparable(x), ) - def dependencies_key(x): + def dependencies_key(x: Key) -> tuple: """Choose which dependency to run as part of a reverse DFS This is very similar to both ``initial_stack_key``. """ + assert dependencies is not None num_dependents = len(dependents[x]) ( total_dependents, @@ -232,14 +242,14 @@ def dependencies_key(x): ) for key, ( total_dependents, - min_dependencies, + _, _, min_heights, _, ) in metrics.items() } - result = {} + result: dict[Key, int] = {} i = 0 # `inner_stack` is used to perform a DFS along dependencies. Once emptied @@ -254,7 +264,7 @@ def dependencies_key(x): # A "better path" is determined by comparing `partition_keys`. 
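As context for the annotations being added here: the public contract of `order` is unchanged by this patch. It takes a low-level graph and returns a dict mapping every key to an integer priority, and dependencies always receive smaller numbers than their dependents. A rough usage sketch (the toy diamond graph below is mine, not taken from the patch):

from dask.order import order


def inc(x):
    return x + 1


def add(x, y):
    return x + y


# a -> (b, c) -> d
dsk = {
    "a": 1,
    "b": (inc, "a"),
    "c": (inc, "a"),
    "d": (add, "b", "c"),
}

priorities = order(dsk)
# Every key gets an int; a task is always ordered after its dependencies.
assert priorities["a"] < priorities["b"] < priorities["d"]
assert priorities["a"] < priorities["c"] < priorities["d"]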
inner_stack = [min(init_stack, key=initial_stack_key)] inner_stack_pop = inner_stack.pop - inner_stacks = [] + inner_stacks: list[list[Key]] = [] inner_stacks_append = inner_stacks.append inner_stacks_extend = inner_stacks.extend inner_stacks_pop = inner_stacks.pop @@ -268,7 +278,7 @@ def dependencies_key(x): # When the inner stacks are depleted, we process `next_nodes`. # These dicts use `partition_keys` as keys. We process them by placing the values # in `outer_stack` so that the smallest keys will be processed first. - next_nodes = defaultdict(list) + next_nodes: defaultdict[int, list[list[Key] | set[Key]]] = defaultdict(list) # `outer_stack` is used to populate `inner_stacks`. From the time we partition the # dependents of a node, we group them: one list per partition key per parent node. @@ -277,7 +287,7 @@ def dependencies_key(x): # partitioned, and we keep them in the order that we saw them (we will process them # in a FIFO manner). By delaying sorting for as long as we can, we can first filter # out nodes that have already been computed. All this complexity is worth it! - outer_stack = [] + outer_stack: list[list[Key]] = [] outer_stack_extend = outer_stack.extend outer_stack_pop = outer_stack.pop @@ -325,9 +335,9 @@ def dependencies_key(x): # scheduler? Should we defer to dynamic schedulers and let them behave like this # if they so choose? Maybe. However, I'm sensitive to the multithreaded scheduler, # which is heavily dependent on the ordering obtained here. - singles = {} + singles: dict[Key, Key] = {} singles_clear = singles.clear - later_singles = [] + later_singles: list[Key] = [] later_singles_append = later_singles.append later_singles_clear = later_singles.clear @@ -348,6 +358,8 @@ def dependencies_key(x): while True: while True: # Perform a DFS along dependencies until we complete our tactical goal + deps = set() + add_to_inner_stack = True if inner_stack: item = inner_stack_pop() if item in result: @@ -394,25 +406,25 @@ def dependencies_key(x): elif later_singles: # No need to be optimistic: all nodes in `later_singles` will free a dependency # when run, so no need to check whether dependents are in `seen`. - deps = set() for single in later_singles: if single in result: continue while True: - dep2 = dependents[single] + deps_singles = dependents[single] result[single] = i i += 1 - if dep2: - for dep in dep2: + if deps_singles: + for dep in deps_singles: num_needed[dep] -= 1 - if len(dep2) == 1: + if len(deps_singles) == 1: # Fast path! We trim down `dep2` above hoping to reach here. - (single,) = dep2 + (single,) = deps_singles if not num_needed[single]: # Keep it going! - dep2 = dependents[single] + deps_singles = dependents[single] continue - deps |= dep2 + deps |= deps_singles + del deps_singles break later_singles_clear() deps = set_difference(deps, result) @@ -426,7 +438,6 @@ def dependencies_key(x): if process_singles and singles: # We gather all dependents of all singles into `deps`, which we then process below. 
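The `singles` / `later_singles` machinery above boils down to a chain-walking fast path: once a task is assigned a priority, if it has exactly one dependent and that dependent is now ready, assign it immediately and keep going. A simplified sketch of just that pattern, detached from the rest of `order` (the helper and the toy inputs are illustrative, not the patch's code):

def run_chain(start, dependents, num_needed, result):
    # Assign increasing priorities along a ready linear chain, decrementing
    # `num_needed` for each dependent as we go.
    i = len(result)
    node = start
    while True:
        result[node] = i
        i += 1
        deps = dependents[node]
        for dep in deps:
            num_needed[dep] -= 1
        if len(deps) == 1:
            (nxt,) = deps
            if num_needed[nxt] == 0:
                node = nxt
                continue
        # Anything left over is not ready yet; hand it back to the caller.
        return {d for d in deps if d not in result}


dependents = {"a": {"b"}, "b": {"c"}, "c": set()}
num_needed = {"a": 0, "b": 1, "c": 1}
result: dict[str, int] = {}
run_chain("a", dependents, num_needed, result)
assert result == {"a": 0, "b": 1, "c": 2}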
- deps = set() add_to_inner_stack = True if inner_stack or inner_stacks else False singles_keys = set_difference(set(singles), result) @@ -455,41 +466,42 @@ def dependencies_key(x): later_singles_append(single) continue while True: - dep2 = dependents[single] + deps_singles = dependents[single] result[single] = i i += 1 - if dep2: - for dep in dep2: + if deps_singles: + for dep in deps_singles: num_needed[dep] -= 1 if add_to_inner_stack: - already_seen = dep2 & seen + already_seen = deps_singles & seen if already_seen: # This means that the singles path also # leads to the current or previous strategic # path - if len(dep2) == len(already_seen): + if len(deps_singles) == len(already_seen): if len(already_seen) == 1: (single,) = already_seen if not num_needed[single]: - dep2 = dependents[single] + deps_singles = dependents[single] continue break - dep2 = dep2 - already_seen + deps_singles = deps_singles - already_seen else: - already_seen = False - if len(dep2) == 1: + already_seen = set() + if len(deps_singles) == 1: # Fast path! We trim down `dep2` above hoping to reach here. - (single,) = dep2 + (single,) = deps_singles if not num_needed[single]: if not already_seen: # Keep it going! - dep2 = dependents[single] + deps_singles = dependents[single] continue later_singles_append(single) break - deps |= dep2 + deps |= deps_singles + del deps_singles break - + del singles_keys deps = set_difference(deps, result) singles_clear() if not deps: @@ -510,6 +522,7 @@ def dependencies_key(x): (dep,) = already_seen if not num_needed[dep]: singles[dep] = item + del dep continue add_to_inner_stack = False deps = deps - already_seen @@ -529,6 +542,7 @@ def dependencies_key(x): singles[dep] = item else: next_nodes[key].append(deps) + del dep, key elif len(deps) == 2: # We special-case when len(deps) == 2 so that we may place a dep on singles. # Otherwise, the logic here is the same as when `len(deps) > 2` below. @@ -581,6 +595,8 @@ def dependencies_key(x): singles[d] = item else: next_nodes[k].append([d]) + del item_key + del prev_key else: assert not inner_stack if add_to_inner_stack: @@ -597,6 +613,7 @@ def dependencies_key(x): else: for k, d in [(key, dep), (key2, dep2)]: next_nodes[k].append([d]) + del dep, dep2, key, key2 else: # Slow path :(. This requires grouping by partition_key. dep_pools = defaultdict(set) @@ -611,6 +628,7 @@ def dependencies_key(x): # If we have an inner_stack, we need to look for a "better" path prev_key = partition_keys[inner_stack[0]] now_keys = [] # < inner_stack[0] + psingles = set() for key, vals in dep_pools.items(): if key < prev_key: now_keys.append(key) @@ -620,23 +638,29 @@ def dependencies_key(x): singles[s] = item vals -= psingles next_nodes[key].append(vals) + del vals, key + del psingles if now_keys: # Run before `inner_stack` (change tactical goal!) inner_stacks_append(inner_stack) if 1 < len(now_keys): now_keys.sort(reverse=True) for key in now_keys: + pool: set[Key] | list[Key] pool = dep_pools[key] if 1 < len(pool) < 100: pool = sorted(pool, key=dependents_key, reverse=True) inner_stacks_extend([dep] for dep in pool) seen_update(pool) + del pool inner_stack = inner_stacks_pop() inner_stack_pop = inner_stack.pop + del now_keys, prev_key else: # If we don't have an inner_stack, then we don't need to look # for a "better" path, but we do need traverse along dependents. 
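Stripped of the ordering-specific bookkeeping, the "slow path" above is a group-by: remaining dependents are pooled by their precomputed partition key, the pool with the smallest key is pursued first, and the rest are deferred to `next_nodes` / `outer_stack`. A generic sketch of that grouping idiom (the inputs here are made up for illustration):

from __future__ import annotations

from collections import defaultdict

# Hypothetical remaining dependents and their precomputed partition keys.
deps = {"x1", "x2", "y1", "z1"}
partition_key = {"x1": 2, "x2": 2, "y1": 5, "z1": 9}

dep_pools: defaultdict[int, set[str]] = defaultdict(set)
for dep in deps:
    dep_pools[partition_key[dep]].add(dep)

# Pursue the most promising pool now; defer the rest for later processing.
min_key = min(dep_pools)
min_pool = dep_pools.pop(min_key)
deferred = dict(sorted(dep_pools.items()))
assert min_key == 2 and min_pool == {"x1", "x2"}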
if add_to_inner_stack: + min_pool: list[Key] | set[Key] min_key = min(dep_pools) min_pool = dep_pools.pop(min_key) if len(min_pool) == 1: @@ -669,7 +693,7 @@ def dependencies_key(x): inner_stack = [min_pool.pop()] next_nodes[min_key].append(min_pool) seen_update(inner_stack) - + del min_pool, min_key inner_stack_pop = inner_stack.pop for key, vals in dep_pools.items(): psingles = possible_singles[key] @@ -677,6 +701,7 @@ def dependencies_key(x): singles[s] = item vals -= psingles next_nodes[key].append(vals) + del key, vals if len(dependencies) == len(result): break # all done! @@ -686,18 +711,20 @@ def dependencies_key(x): # `outer_stacks` may not be empty here--it has data from previous `next_nodes`. # Since we pop things off of it (onto `inner_nodes`), this means we handle # multiple `next_nodes` in a LIFO manner. - outer_stack_extend(reversed(next_nodes[key])) - next_nodes = defaultdict(list) + outer_stack_extend(list(el) for el in reversed(next_nodes[key])) + next_nodes.clear() + outer_deps = [] while outer_stack: # Try to add a few items to `inner_stacks` - deps = [x for x in outer_stack_pop() if x not in result] - if deps: - if 1 < len(deps) < 100: - deps.sort(key=dependents_key, reverse=True) - inner_stacks_extend([dep] for dep in deps) - seen_update(deps) + outer_deps = [x for x in outer_stack_pop() if x not in result] + if outer_deps: + if 1 < len(outer_deps) < 100: + outer_deps.sort(key=dependents_key, reverse=True) + inner_stacks_extend([dep] for dep in outer_deps) + seen_update(outer_deps) break + del outer_deps if inner_stacks: continue @@ -709,8 +736,7 @@ def dependencies_key(x): # If we have many tiny groups left, then it's best to simply iterate. if not is_init_sorted: prev_len = len(init_stack) - if type(init_stack) is dict: - init_stack = set(init_stack) + init_stack = set(init_stack) init_stack = set_difference(init_stack, result) N = len(init_stack) m = prev_len - N @@ -736,7 +762,11 @@ def dependencies_key(x): return result -def graph_metrics(dependencies, dependents, total_dependencies): +def graph_metrics( + dependencies: Mapping[Key, set[Key]], + dependents: Mapping[Key, set[Key]], + total_dependencies: Mapping[Key, int], +) -> dict[Key, tuple[int, int, int, int, int]]: r"""Useful measures of a graph used by ``dask.order.order`` Example DAG (a1 has no dependencies; b2 and c1 are root nodes): @@ -815,7 +845,7 @@ def graph_metrics(dependencies, dependents, total_dependencies): """ result = {} num_needed = {k: len(v) for k, v in dependents.items() if v} - current = [] + current: list[Key] = [] current_pop = current.pop current_append = current.append for key, deps in dependents.items(): @@ -848,18 +878,18 @@ def graph_metrics(dependencies, dependents, total_dependencies): ) else: ( - total_dependents, - min_dependencies, - max_dependencies, - min_heights, - max_heights, + total_dependents_, + min_dependencies_, + max_dependencies_, + min_heights_, + max_heights_, ) = zip(*(result[parent] for parent in dependents[key])) result[key] = ( - 1 + sum(total_dependents), - min(min_dependencies), - max(max_dependencies), - 1 + min(min_heights), - 1 + max(max_heights), + 1 + sum(total_dependents_), + min(min_dependencies_), + max(max_dependencies_), + 1 + min(min_heights_), + 1 + max(max_heights_), ) for child in dependencies[key]: num_needed[child] -= 1 @@ -868,7 +898,9 @@ def graph_metrics(dependencies, dependents, total_dependencies): return result -def ndependencies(dependencies, dependents): +def ndependencies( + dependencies: Mapping[Key, set[Key]], dependents: 
Mapping[Key, set[Key]] +) -> tuple[dict[Key, int], dict[Key, int]]: """Number of total data elements on which this key depends For each key we return the number of tasks that must be run for us to run @@ -896,7 +928,7 @@ def ndependencies(dependencies, dependents): result[k] = 1 num_dependencies = num_needed.copy() - current = [] + current: list[Key] = [] current_pop = current.pop current_append = current.append @@ -934,10 +966,12 @@ class StrComparable: __slots__ = ("obj",) - def __init__(self, obj): + obj: Any + + def __init__(self, obj: Any): self.obj = obj - def __lt__(self, other): + def __lt__(self, other: Any) -> bool: try: return self.obj < other.obj except Exception: @@ -956,7 +990,11 @@ def __lt__(self, other): ) -def diagnostics(dsk, o=None, dependencies=None): +def diagnostics( + dsk: MutableMapping[Key, Any], + o: Mapping[Key, int] | None = None, + dependencies: MutableMapping[Key, set[Key]] | None = None, +) -> tuple[dict[Key, OrderInfo], list[int]]: """Simulate runtime metrics as though running tasks one at a time in order. These diagnostics can help reveal behaviors of and issues with ``order``. @@ -1011,14 +1049,14 @@ def diagnostics(dsk, o=None, dependencies=None): return rv, pressure -def _f(): +def _f() -> None: ... -def _convert_task(task): +def _convert_task(task: Any) -> Any: if istask(task): assert callable(task[0]) - new_spec = [] + new_spec: list[Any] = [] for el in task[1:]: if isinstance(el, (str, int)): new_spec.append(el) @@ -1034,7 +1072,7 @@ def _convert_task(task): return task -def sanitize_dsk(dsk): +def sanitize_dsk(dsk: MutableMapping[Key, Any]) -> dict: """Take a dask graph and replace callables with a dummy function and remove payload data like numpy arrays, dataframes, etc. """ diff --git a/dask/typing.py b/dask/typing.py index 5cb71bd759e..5a47f2d8516 100644 --- a/dask/typing.py +++ b/dask/typing.py @@ -26,6 +26,8 @@ Key: TypeAlias = Union[str, bytes, int, float, tuple["Key", ...]] +# FIXME: This type is a little misleading. Low level graphs are often +# MutableMappings but HLGs are not Graph: TypeAlias = Mapping[Key, Any] # Potentially nested list of Dask keys NestedKeys: TypeAlias = list[Union[Key, "NestedKeys"]] diff --git a/pyproject.toml b/pyproject.toml index c2903525307..e938d7aa349 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -166,6 +166,14 @@ warn_redundant_casts = true warn_unused_ignores = true warn_unreachable = true +[[tool.mypy.overrides]] + +# Recent or recently overhauled modules featuring stricter validation +module = [ + "dask.order", +] +allow_untyped_defs = false + [tool.codespell] ignore-words-list = "coo,nd" skip = "docs/source/changelog.rst"
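For context on the `first_pressure = max(diagnostics(first)[1])` comparisons in the tests earlier in this series: `diagnostics` simulates running the graph one task at a time in the computed order and returns per-key `OrderInfo` records plus a list of ints that tracks, roughly, how many intermediate results are held in memory at each step; its maximum is what the tests treat as peak memory pressure. A small usage sketch (the toy graph is illustrative, not taken from the patch):

from dask.order import diagnostics


def inc(x):
    return x + 1


def add(x, y):
    return x + y


dsk = {
    "a": 1,
    "b": (inc, "a"),
    "c": (inc, "a"),
    "d": (add, "b", "c"),
}

info, pressure = diagnostics(dsk)
# `info` maps each key to an OrderInfo record of per-task metrics;
# `pressure` is a list of ints whose max approximates peak memory pressure.
print(max(pressure))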