Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: debugging slicing on gpu #18

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 25 additions & 7 deletions src/awkward/_slicing.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,19 @@

SliceItem: TypeAlias = "int | slice | str | None | Ellipsis | ArrayLike | Content"

import functools

def trace_function_calls(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
print(f"_slicing.py: Calling function: {func.__name__}")
result = func(*args, **kwargs)
print(f"_slicing.py: Function {func.__name__} returned {result}")
return result
return wrapper


@trace_function_calls
def normalize_slice(slice_: slice, *, nplike: NumpyLike) -> slice:
"""
Args:
Expand Down Expand Up @@ -59,13 +71,15 @@ def __repr__(self):
S = TypeVar("S", bound=Sequence)


@trace_function_calls
def head_tail(sequence: S[T]) -> tuple[T | type(NO_HEAD), S[T]]:
if len(sequence) == 0:
return NO_HEAD, ()
else:
return sequence[0], sequence[1:]


@trace_function_calls
def prepare_advanced_indexing(items, backend: Backend):
"""Broadcast index objects to satisfy NumPy indexing rules

Expand Down Expand Up @@ -177,7 +191,7 @@ def prepare_advanced_indexing(items, backend: Backend):
)
return tuple(prepared)


@trace_function_calls
def normalize_integer_like(x) -> int | ArrayLike:
if is_array_like(x):
if np.issubdtype(x.dtype, np.integer) and x.ndim == 0:
Expand All @@ -187,7 +201,7 @@ def normalize_integer_like(x) -> int | ArrayLike:
else:
return int(x)


@trace_function_calls
def normalise_item(item, backend: Backend) -> SliceItem:
"""
Args:
Expand All @@ -200,9 +214,11 @@ def normalise_item(item, backend: Backend) -> SliceItem:
"""
# Basic indices
if is_integer_like(item):
print(" normalize_item::integer_like", item)
return normalize_integer_like(item)

elif isinstance(item, slice):
print(" normalize_item::slice", item)
return normalize_slice(item, nplike=backend.index_nplike)

elif isinstance(item, str):
Expand Down Expand Up @@ -300,12 +316,14 @@ def normalise_item(item, backend: Backend) -> SliceItem:
+ repr(item).replace("\n", "\n ")
)


@trace_function_calls
def normalise_items(where: Sequence, backend: Backend) -> list:
# First prepare items for broadcasting into like-types
for x in where:
print(" normalise_items", x, where)
return [normalise_item(x, backend=backend) for x in where]


@trace_function_calls
def _normalise_item_RegularArray_to_ListOffsetArray64(item: Content) -> Content:
if isinstance(item, ak.contents.RegularArray):
next = item.to_ListOffsetArray64()
Expand All @@ -321,7 +339,7 @@ def _normalise_item_RegularArray_to_ListOffsetArray64(item: Content) -> Content:
else:
raise AssertionError(type(item))


@trace_function_calls
def _normalise_item_nested(item: Content) -> Content:
if isinstance(item, ak.contents.EmptyArray):
# policy: unknown -> int
Expand Down Expand Up @@ -460,7 +478,7 @@ def _normalise_item_nested(item: Content) -> Content:
+ repr(item).replace("\n", "\n ")
)


@trace_function_calls
def _normalise_item_bool_to_int(item: Content, backend: Backend) -> Content:
"""
Args:
Expand Down Expand Up @@ -650,7 +668,7 @@ def _normalise_item_bool_to_int(item: Content, backend: Backend) -> Content:
else:
raise AssertionError(type(item))


@trace_function_calls
def getitem_next_array_wrap(
outcontent: Content, shape: tuple[int], outer_length: int = 0
) -> Content:
Expand Down
48 changes: 48 additions & 0 deletions src/awkward/contents/content.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,16 @@
JSONValueType: TypeAlias = """
float | int | str | list[JSONValueType] | dict[str, JSONValueType]
"""
import functools

def trace_function_calls(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
print(f"content.py: Calling function: {func.__name__}")
result = func(*args, **kwargs)
print(f"content.py: Function {func.__name__} returned {result}")
return result
return wrapper

class ImplementsApplyAction(Protocol):
def __call__(
Expand Down Expand Up @@ -298,6 +307,7 @@ def __iter__(self):
for i in range(len(self)):
yield self._getitem_at(i)

@trace_function_calls
def _getitem_next_field(
self,
head: SliceItem | tuple,
Expand All @@ -307,6 +317,7 @@ def _getitem_next_field(
nexthead, nexttail = ak._slicing.head_tail(tail)
return self._getitem_field(head)._getitem_next(nexthead, nexttail, advanced)

@trace_function_calls
def _getitem_next_fields(
self, head: SliceItem, tail: tuple[SliceItem, ...], advanced: Index | None
) -> Content:
Expand All @@ -321,6 +332,8 @@ def _getitem_next_fields(
nexthead, nexttail, advanced
)


@trace_function_calls
def _getitem_next_newaxis(
self, tail: tuple[SliceItem, ...], advanced: Index | None
) -> RegularArray:
Expand All @@ -329,6 +342,7 @@ def _getitem_next_newaxis(
self._getitem_next(nexthead, nexttail, advanced), 1, 0, parameters=None
)

@trace_function_calls
def _getitem_next_ellipsis(
self, tail: tuple[SliceItem, ...], advanced: Index | None
) -> Content:
Expand All @@ -353,6 +367,7 @@ def _getitem_next_ellipsis(
else:
return self._getitem_next(slice(None), (Ellipsis, *tail), advanced)

@trace_function_calls
def _getitem_next_regular_missing(
self,
head: IndexedOptionArray,
Expand Down Expand Up @@ -395,6 +410,7 @@ def _getitem_next_regular_missing(
out, indexlength, 1, parameters=self._parameters
)

@trace_function_calls
def _getitem_next_missing_jagged(
self, head: Content, tail, advanced: Index | None, that: Content
) -> RegularArray:
Expand Down Expand Up @@ -447,6 +463,7 @@ def _getitem_next_missing_jagged(
out, index.length, 1, parameters=self._parameters
)

@trace_function_calls
def _getitem_next_missing(
self,
head: IndexedOptionArray,
Expand Down Expand Up @@ -508,71 +525,92 @@ def _getitem_next_missing(
f"FIXME: unhandled case of SliceMissing with {nextcontent}"
)

@trace_function_calls
def __getitem__(self, where):
print(" content::__getitem__ of ", self, self.backend)
return self._getitem(where)

@trace_function_calls
def _getitem(self, where):
if is_integer_like(where):
print(" in content::_getitem: is integer like!", where)
return self._getitem_at(ak._slicing.normalize_integer_like(where))

elif isinstance(where, slice) and where.step is None:
print(" content::slice with step???")
# Ensure that start, stop are non-negative!
start, stop, _, _ = self._backend.index_nplike.derive_slice_for_length(
normalize_slice(where, nplike=self._backend.index_nplike), self.length
)
print(" >>>start, stop", start, stop)
return self._getitem_range(start, stop)

elif isinstance(where, slice):
print(" content::slice???")
return self._getitem((where,))

elif isinstance(where, str):
print(" content::str")
return self._getitem_field(where)

elif where is np.newaxis:
print(" content::axis")
return self._getitem((where,))

elif where is Ellipsis:
print(" content::Ellipsis")
return self._getitem((where,))

elif isinstance(where, tuple):
print(" content::tuple", where)
if len(where) == 0:
print(" len 0")
return self

# Backend may change if index contains typetracers
backend = backend_of(self, *where, coerce_to_common=True)
this = self.to_backend(backend)
print(" this", this)

# Normalise valid indices onto well-defined basis
items = ak._slicing.normalise_items(where, backend)
print(" items", items)
# Prepare items for advanced indexing (e.g. via broadcasting)
nextwhere = ak._slicing.prepare_advanced_indexing(items, backend)
print(" nextwhere", nextwhere)

next = ak.contents.RegularArray(
this,
this.length,
1,
parameters=None,
)
print(" next", next)


out = next._getitem_next(nextwhere[0], nextwhere[1:], None)
print(" content::tuple out", out)

if out.length is not unknown_length and out.length == 0:
return out._getitem_nothing()
else:
return out._getitem_at(0)

elif isinstance(where, ak.highlevel.Array):
print(" content::Array")
return self._getitem(where.layout)

# Convert between nplikes of different backends
elif (
isinstance(where, ak.contents.Content)
and where.backend is not self._backend
):
print(" content::backends")
backend = backend_of(self, where, coerce_to_common=True)
return self.to_backend(backend)._getitem(where.to_backend(backend))

elif isinstance(where, ak.contents.NumpyArray):
print(" content::NumpyArray")
data_as_index = to_nplike(
where.data,
self._backend.index_nplike,
Expand Down Expand Up @@ -693,25 +731,31 @@ def _getitem(self, where):
+ repr(where).replace("\n", "\n ")
)

@trace_function_calls
def _is_getitem_at_placeholder(self) -> bool:
raise NotImplementedError

@trace_function_calls
def _getitem_at(self, where: IndexType):
raise NotImplementedError

@trace_function_calls
def _getitem_range(self, start: IndexType, stop: IndexType) -> Content:
raise NotImplementedError

@trace_function_calls
def _getitem_field(
self, where: str | SupportsIndex, only_fields: tuple[str, ...] = ()
) -> Content:
raise NotImplementedError

@trace_function_calls
def _getitem_fields(
self, where: list[str], only_fields: tuple[str, ...] = ()
) -> Content:
raise NotImplementedError

@trace_function_calls
def _getitem_next(
self,
head: SliceItem | tuple,
Expand All @@ -720,6 +764,7 @@ def _getitem_next(
) -> Content:
raise NotImplementedError

@trace_function_calls
def _getitem_next_jagged(
self,
slicestarts: Index,
Expand All @@ -729,9 +774,11 @@ def _getitem_next_jagged(
) -> Content:
raise NotImplementedError

@trace_function_calls
def _carry(self, carry: Index, allow_lazy: bool) -> Content:
raise NotImplementedError

@trace_function_calls
def _local_index_axis0(self) -> NumpyArray:
localindex = Index64.empty(self.length, self._backend.index_nplike)
self._backend.maybe_kernel_error(
Expand All @@ -744,6 +791,7 @@ def _local_index_axis0(self) -> NumpyArray:
localindex.data, parameters=None, backend=self._backend
)

@trace_function_calls
def _mergeable_next(self, other: Content, mergebool: bool) -> bool:
raise NotImplementedError

Expand Down
Loading
Loading