Merge remote-tracking branch 'upstream/main' into shuffle-drop_duplicates
rjzamora committed Oct 12, 2023
2 parents 2b0d866 + 9f7d557 commit ec83b51
Showing 14 changed files with 1,068 additions and 363 deletions.
6 changes: 3 additions & 3 deletions dask/array/core.py
@@ -5803,10 +5803,10 @@ def __getitem__(self, index: Any) -> Array:

keys = product(*(range(len(c)) for c in chunks))

-        layer = {(name,) + key: tuple(new_keys[key].tolist()) for key in keys}
+        graph: Graph = {(name,) + key: tuple(new_keys[key].tolist()) for key in keys}

-        graph = HighLevelGraph.from_collections(name, layer, dependencies=[self._array])
-        return Array(graph, name, chunks, meta=self._array)
+        hlg = HighLevelGraph.from_collections(name, graph, dependencies=[self._array])
+        return Array(hlg, name, chunks, meta=self._array)

def __eq__(self, other: Any) -> bool:
if isinstance(other, BlockView):
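For orientation, the hunk above is the `Array.blocks` indexing path; a minimal usage sketch (an editor's illustration, not part of the diff):

```python
# Sketch: indexing Array.blocks calls BlockView.__getitem__, which
# builds the per-block graph annotated in the hunk above.
import dask.array as da

x = da.ones((4, 4), chunks=(2, 2))
first = x.blocks[0, 0]  # a 2x2 Array backed by a single chunk
print(first.compute())
```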
4 changes: 2 additions & 2 deletions dask/bag/core.py
@@ -7,7 +7,7 @@
import uuid
import warnings
from collections import defaultdict
-from collections.abc import Iterable, Iterator, Mapping, Sequence
+from collections.abc import Iterable, Iterator, Sequence
from functools import partial, reduce, wraps
from random import Random
from urllib.request import urlopen
@@ -469,7 +469,7 @@ class Bag(DaskMethodsMixin):
30
"""

-    def __init__(self, dsk: Mapping, name: str, npartitions: int):
+    def __init__(self, dsk: Graph, name: str, npartitions: int):
if not isinstance(dsk, HighLevelGraph):
dsk = HighLevelGraph.from_collections(name, dsk, dependencies=[])
self.dask = dsk
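A quick sketch of what the newly typed constructor accepts, grounded in the pattern the test file below uses:

```python
# Sketch: Bag takes any Graph-typed mapping and, per the branch above,
# wraps a plain dict in a HighLevelGraph itself.
from dask.bag import Bag

graph = {("b", 0): (range, 5), ("b", 1): (range, 5)}
b = Bag(graph, name="b", npartitions=2)
assert b.compute() == list(range(5)) * 2
```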
3 changes: 2 additions & 1 deletion dask/bag/tests/test_bag.py
@@ -34,10 +34,11 @@
from dask.bag.utils import assert_eq
from dask.blockwise import Blockwise
from dask.delayed import Delayed
+from dask.typing import Graph
from dask.utils import filetexts, tmpdir, tmpfile
from dask.utils_test import add, hlg_layer, hlg_layer_topological, inc

-dsk = {("x", 0): (range, 5), ("x", 1): (range, 5), ("x", 2): (range, 5)}
+dsk: Graph = {("x", 0): (range, 5), ("x", 1): (range, 5), ("x", 2): (range, 5)}

L = list(range(5)) * 3

45 changes: 36 additions & 9 deletions dask/core.py
@@ -1,10 +1,10 @@
from __future__ import annotations

from collections import defaultdict
-from collections.abc import Collection, Iterable
-from typing import Any, cast
+from collections.abc import Collection, Iterable, Mapping
+from typing import Any, Literal, TypeVar, cast, overload

-from dask.typing import Key, no_default
+from dask.typing import Graph, Key, NoDefault, no_default


def ishashable(x):
@@ -223,7 +223,32 @@ def validate_key(key: object) -> None:
raise TypeError(f"Unexpected key type {type(key)} (value: {key!r})")


-def get_dependencies(dsk, key=None, task=no_default, as_list=False):
+@overload
+def get_dependencies(
+    dsk: Graph,
+    key: Key | None = ...,
+    task: Key | NoDefault = ...,
+    as_list: Literal[False] = ...,
+) -> set[Key]:
+    ...
+
+
+@overload
+def get_dependencies(
+    dsk: Graph,
+    key: Key | None,
+    task: Key | NoDefault,
+    as_list: Literal[True],
+) -> list[Key]:
+    ...
+
+
+def get_dependencies(
+    dsk: Graph,
+    key: Key | None = None,
+    task: Key | NoDefault = no_default,
+    as_list: bool = False,
+) -> set[Key] | list[Key]:
"""Get the immediate tasks on which this task depends
Examples
Expand Down Expand Up @@ -264,7 +289,7 @@ def get_dependencies(dsk, key=None, task=no_default, as_list=False):
return keys_in_tasks(dsk, [arg], as_list=as_list)
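The paired `@overload`s above exist so type checkers can narrow the return type on the `as_list` literal; a hedged sketch of the effect:

```python
# Sketch: with the default as_list=False a checker infers set[Key];
# with a literal True it infers list[Key].
from dask.core import get_dependencies

dsk = {"a": 1, "b": (lambda x: x + 1, "a")}
deps_set = get_dependencies(dsk, "b")                 # {"a"}, typed set[Key]
deps_list = get_dependencies(dsk, "b", as_list=True)  # ["a"], typed list[Key]
```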


-def get_deps(dsk):
+def get_deps(dsk: Graph) -> tuple[dict[Key, set[Key]], dict[Key, set[Key]]]:
"""Get dependencies and dependents from dask dask graph
>>> inc = lambda x: x + 1
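The new annotation spells out the pair `get_deps` returns; an illustrative sketch:

```python
# Sketch: get_deps returns the (dependencies, dependents) dicts
# promised by the annotation above.
from dask.core import get_deps

dsk = {"a": 1, "b": (lambda x: x + 1, "a")}
dependencies, dependents = get_deps(dsk)
assert dependencies == {"a": set(), "b": {"a"}}
assert dependents == {"a": {"b"}, "b": set()}
```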
@@ -308,22 +333,24 @@ def flatten(seq, container=list):
yield item


-def reverse_dict(d):
+T_ = TypeVar("T_")
+
+
+def reverse_dict(d: Mapping[T_, Iterable[T_]]) -> dict[T_, set[T_]]:
"""
>>> a, b, c = 'abc'
>>> d = {a: [b, c], b: [c]}
>>> reverse_dict(d) # doctest: +SKIP
{'a': set([]), 'b': set(['a']), 'c': set(['a', 'b'])}
"""
-    result = defaultdict(set)
+    result: defaultdict[T_, set[T_]] = defaultdict(set)
_add = set.add
for k, vals in d.items():
result[k]
for val in vals:
_add(result[val], k)
-    result.default_factory = None
-    return result
+    return dict(result)
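Behavioral note on that last change: freezing `default_factory` to `None` made unknown keys raise `KeyError`, and returning a plain `dict` preserves that contract. A sketch grounded in the docstring example:

```python
# Sketch: the plain-dict return matches the docstring example;
# lookups of keys absent from the input still raise KeyError.
from dask.core import reverse_dict

d = {"a": ["b", "c"], "b": ["c"]}
assert reverse_dict(d) == {"a": set(), "b": {"a"}, "c": {"a", "b"}}
```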


def subs(task, key, val):
8 changes: 5 additions & 3 deletions dask/dataframe/io/csv.py
@@ -486,10 +486,12 @@ def read_pandas(
lineterminator = "\n"
if include_path_column and isinstance(include_path_column, bool):
include_path_column = "path"
if "index" in kwargs or "index_col" in kwargs:
if "index" in kwargs or (
"index_col" in kwargs and kwargs.get("index_col") is not False
):
raise ValueError(
"Keywords 'index' and 'index_col' not supported. "
f"Use dd.{reader_name}(...).set_index('my-index') instead"
"Keywords 'index' and 'index_col' not supported, except for "
"'index_col=False'. Use dd.{reader_name}(...).set_index('my-index') instead"
)
for kw in ["iterator", "chunksize"]:
if kw in kwargs:
4 changes: 4 additions & 0 deletions dask/dataframe/io/tests/test_csv.py
@@ -1117,6 +1117,10 @@ def test_index_col():
except ValueError as e:
assert "set_index" in str(e)

+        df = pd.read_csv(fn, index_col=False)
+        ddf = dd.read_csv(fn, blocksize=30, index_col=False)
+        assert_eq(df, ddf, check_index=False)


def test_read_csv_with_datetime_index_partitions_one():
with filetext(timeseries) as fn:
10 changes: 5 additions & 5 deletions dask/graph_manipulation.py
@@ -5,8 +5,8 @@
from __future__ import annotations

import uuid
-from collections.abc import Callable, Hashable, Set
-from typing import Any, Literal, TypeVar
+from collections.abc import Callable, Hashable
+from typing import Literal, TypeVar

from dask.base import (
clone_key,
@@ -20,7 +20,7 @@
from dask.core import flatten
from dask.delayed import Delayed, delayed
from dask.highlevelgraph import HighLevelGraph, Layer, MaterializedLayer
-from dask.typing import Key
+from dask.typing import Graph, Key

__all__ = ("bind", "checkpoint", "clone", "wait_on")

@@ -78,7 +78,7 @@ def _checkpoint_one(collection, split_every) -> Delayed:
next(keys_iter)
except StopIteration:
# Collection has 0 or 1 keys; no need for a map step
-        layer = {name: (chunks.checkpoint, collection.__dask_keys__())}
+        layer: Graph = {name: (chunks.checkpoint, collection.__dask_keys__())}
dsk = HighLevelGraph.from_collections(name, layer, dependencies=(collection,))
return Delayed(name, dsk)
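`_checkpoint_one` backs the public `checkpoint` helper; a hedged usage sketch:

```python
# Sketch: checkpoint() returns a Delayed that completes only after
# every task in the given collection has executed.
import dask.bag as db
from dask.graph_manipulation import checkpoint

b = db.from_sequence(range(10), npartitions=2).map(lambda x: x + 1)
done = checkpoint(b)
done.compute()  # None, once all of b's tasks have run
```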

Expand Down Expand Up @@ -321,7 +321,7 @@ def _bind_one(

dsk = child.__dask_graph__() # type: ignore
new_layers: dict[str, Layer] = {}
-    new_deps: dict[str, Set[Any]] = {}
+    new_deps: dict[str, set[str]] = {}

if isinstance(dsk, HighLevelGraph):
try:
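`_bind_one` underlies the public `bind`; a minimal sketch of the dependency it injects:

```python
# Sketch: bind() returns a clone of the children whose tasks also
# depend on the parents having completed.
import dask.bag as db
from dask.graph_manipulation import bind

parents = db.from_sequence([1, 2, 3])
children = db.from_sequence([4, 5, 6])
bound = bind(children, parents)  # computes only after parents
assert bound.compute() == [4, 5, 6]
```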
(7 more changed files not shown)
