Skip to content

Commit

Permalink
make _ListMerger a function
Browse files Browse the repository at this point in the history
  • Loading branch information
alexfikl committed Sep 10, 2022
1 parent ffd77b5 commit ceaf8b2
Showing 1 changed file with 112 additions and 106 deletions.
218 changes: 112 additions & 106 deletions boxtree/traversal.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
from pyopencl.elementwise import ElementwiseTemplate, ElementwiseKernel

from arraycontext import Array
from pytools import ProcessLogger, log_process, memoize_method
from pytools import ProcessLogger, log_process, memoize_method, memoize_in
from pytools.obj_array import make_obj_array
from mako.template import Template

Expand Down Expand Up @@ -1180,114 +1180,124 @@ class _IndexStyle(enum.IntEnum):


class _ListMerger:
"""Utility class for combining box lists optionally changing indexing style."""

def __init__(self, array_context: PyOpenCLArrayContext, box_id_dtype):
self._setup_actx = array_context
self.box_id_dtype = box_id_dtype

@property
def context(self):
return self._setup_actx.queue.context
def __call__(self, actx, input_starts, input_lists, input_index_style,
output_index_style, target_boxes, target_or_target_parent_boxes,
nboxes, debug=False, wait_for=None):
from warnings import warn
warn(f"'{type(self).__name__}' is deprecated and will be removed in 2023. "
"Use 'merge_lists' instead.",
DeprecationWarning, stacklevel=2)

return merge_lists(
actx, input_starts, input_lists,
input_index_style, output_index_style,
target_boxes, target_or_target_parent_boxes, nboxes, self.box_id_dtype,
debug=debug)


def merge_lists(
actx: PyOpenCLArrayContext, input_starts, input_lists, input_index_style,
output_index_style, target_boxes, target_or_target_parent_boxes,
nboxes, box_id_dtype, debug: bool = False):
"""Utility class for combining box lists optionally changing indexing style.
:arg input_starts: Starts arrays of input
:arg input_lists: Lists arrays of input
:arg input_index_style: A :class:`_IndexStyle`
:arg output_index_style: A :class:`_IndexStyle`
"""
# {{{

if (
output_index_style == _IndexStyle.TARGET_OR_TARGET_PARENT_BOXES
and input_index_style == _IndexStyle.TARGET_BOXES):
raise ValueError(
"unsupported: merging a list indexed by target boxes "
"into a list indexed by target or target parent boxes")

ntarget_boxes = len(target_boxes)
ntarget_or_ntarget_parent_boxes = len(target_or_target_parent_boxes)

noutput_boxes = (ntarget_boxes
if output_index_style == _IndexStyle.TARGET_BOXES
else ntarget_or_ntarget_parent_boxes)

if (
input_index_style == _IndexStyle.TARGET_OR_TARGET_PARENT_BOXES
and output_index_style == _IndexStyle.TARGET_BOXES):
from boxtree.tools import reverse_index_array
target_or_target_parent_boxes_from_all_boxes = reverse_index_array(
actx, target_or_target_parent_boxes, target_size=nboxes)
target_or_target_parent_boxes_from_target_boxes = (
target_or_target_parent_boxes_from_all_boxes[target_boxes])

output_to_input_box = target_or_target_parent_boxes_from_target_boxes
else:
output_to_input_box = actx.from_numpy(
np.arange(noutput_boxes, dtype=box_id_dtype))

new_counts = actx.empty(noutput_boxes + 1, box_id_dtype)
assert len(input_starts) == len(input_lists)

nlists = len(input_starts)
assert nlists >= 1

@memoize_method
def get_list_merger_kernel(self, nlists, write_counts):
"""
:arg nlists: Number of input lists
:arg write_counts: A :class:`bool`, indicating whether to generate a
kernel that produces box counts or box lists
"""
assert nlists >= 1
# }}}

# {{{ merge lists

@memoize_in(actx, (merge_lists, box_id_dtype, nlists))
def get_list_merger_kernel(with_write_counts):
return LIST_MERGER_TEMPLATE.build(
self.context,
actx.context,
type_aliases=(
("box_id_t", self.box_id_dtype),
("box_id_t", box_id_dtype),
),
var_values=(
("nlists", nlists),
("write_counts", write_counts),
("write_counts", with_write_counts),
))

def __call__(self, actx, input_starts, input_lists, input_index_style,
output_index_style, target_boxes, target_or_target_parent_boxes,
nboxes, debug=False, wait_for=None):
"""
:arg input_starts: Starts arrays of input
:arg input_lists: Lists arrays of input
:arg input_index_style: A :class:`_IndexStyle`
:arg output_index_style: A :class:`_IndexStyle`
:returns: A pair *results_dict, event*, where *results_dict*
contains entries *starts* and *lists*
"""
if wait_for is None:
wait_for = []

if (
output_index_style == _IndexStyle.TARGET_OR_TARGET_PARENT_BOXES
and input_index_style == _IndexStyle.TARGET_BOXES):
raise ValueError(
"unsupported: merging a list indexed by target boxes "
"into a list indexed by target or target parent boxes")

ntarget_boxes = len(target_boxes)
ntarget_or_ntarget_parent_boxes = len(target_or_target_parent_boxes)

noutput_boxes = (ntarget_boxes
if output_index_style == _IndexStyle.TARGET_BOXES
else ntarget_or_ntarget_parent_boxes)

if (
input_index_style == _IndexStyle.TARGET_OR_TARGET_PARENT_BOXES
and output_index_style == _IndexStyle.TARGET_BOXES):
from boxtree.tools import reverse_index_array
target_or_target_parent_boxes_from_all_boxes = reverse_index_array(
actx, target_or_target_parent_boxes, target_size=nboxes)
target_or_target_parent_boxes_from_target_boxes = (
target_or_target_parent_boxes_from_all_boxes[target_boxes]
)

output_to_input_box = target_or_target_parent_boxes_from_target_boxes
else:
output_to_input_box = actx.from_numpy(
np.arange(noutput_boxes, dtype=self.box_id_dtype)
)

new_counts = actx.empty(noutput_boxes + 1, self.box_id_dtype)

assert len(input_starts) == len(input_lists)
nlists = len(input_starts)

evt = self.get_list_merger_kernel(nlists, True)(*(
# input:
(output_to_input_box,)
+ input_starts
# output:
+ (new_counts,)),
range=slice(noutput_boxes),
queue=actx.queue,
wait_for=wait_for)

import pyopencl.array as cl_array
new_starts = cl_array.cumsum(new_counts)
del new_counts
evt = get_list_merger_kernel(True)(*(
# input:
(output_to_input_box,)
+ input_starts
# output:
+ (new_counts,)),
range=slice(noutput_boxes),
queue=actx.queue,
)
new_counts.add_event(evt)

import pyopencl.array as cl_array
new_starts = cl_array.cumsum(new_counts)
del new_counts

new_lists = actx.empty(int(actx.to_numpy(new_starts[-1])), box_id_dtype)
new_lists.fill(999999999)

evt = get_list_merger_kernel(False)(*(
# input:
(output_to_input_box,)
+ input_starts
+ input_lists
+ (new_starts,)
# output:
+ (new_lists,)),
range=slice(noutput_boxes),
queue=actx.queue,
)
new_starts.add_event(evt)
new_lists.add_event(evt)

new_lists = actx.empty(int(actx.to_numpy(new_starts[-1])), self.box_id_dtype)
new_lists.fill(999999999)
# }}}

evt = self.get_list_merger_kernel(nlists, False)(*(
# input:
(output_to_input_box,)
+ input_starts
+ input_lists
+ (new_starts,)
# output:
+ (new_lists,)),
range=slice(noutput_boxes),
queue=actx.queue,
wait_for=[evt])
return {"starts": new_starts, "lists": new_lists}

return dict(starts=new_starts, lists=new_lists), evt

# }}}

Expand Down Expand Up @@ -1592,18 +1602,18 @@ def ntarget_or_target_parent_boxes(self):

# {{{ "close" list merging -> "unified list 1"

def merge_close_lists(self, actx, debug=False):
def merge_close_lists(self,
actx: PyOpenCLArrayContext,
debug: bool = False) -> "FMMTraversalInfo":
"""Return a new :class:`FMMTraversalInfo` instance with the contents of
:attr:`from_sep_close_smaller_starts` and
:attr:`from_sep_close_bigger_starts` merged into
:attr:`neighbor_source_boxes_starts` and these two attributes set to
*None*.
"""

list_merger = _ListMerger(actx, self.tree.box_id_dtype)

result, evt = (
list_merger(
result = (
merge_lists(
actx,
# starts
(self.neighbor_source_boxes_starts,
Expand All @@ -1621,11 +1631,9 @@ def merge_close_lists(self, actx, debug=False):
self.target_boxes,
self.target_or_target_parent_boxes,
self.tree.nboxes,
self.tree.box_id_dtype,
debug))

import pyopencl as cl
cl.wait_for_events([evt])

from dataclasses import replace
return replace(self,
neighbor_source_boxes_starts=actx.freeze(result["starts"]),
Expand Down Expand Up @@ -2159,8 +2167,7 @@ def extract_level_start_box_nrs(box_list, wait_for):
from_sep_close_bigger_starts_raw = result["from_sep_close_bigger"].starts
from_sep_close_bigger_lists_raw = result["from_sep_close_bigger"].lists

list_merger = _ListMerger(actx, tree.box_id_dtype)
result, evt = list_merger(
result = merge_lists(
actx,
# starts
(from_sep_close_bigger_starts_raw,),
Expand All @@ -2174,10 +2181,9 @@ def extract_level_start_box_nrs(box_list, wait_for):
target_boxes,
target_or_target_parent_boxes,
tree.nboxes,
tree.box_id_dtype,
debug,
wait_for=wait_for)

wait_for = [evt]
)

del from_sep_close_bigger_starts_raw
del from_sep_close_bigger_lists_raw
Expand Down

0 comments on commit ceaf8b2

Please sign in to comment.