From ceaf8b296461e65612cd34a568e5346460d380a9 Mon Sep 17 00:00:00 2001 From: Alexandru Fikl Date: Sat, 10 Sep 2022 23:06:07 +0300 Subject: [PATCH] make _ListMerger a function --- boxtree/traversal.py | 218 ++++++++++++++++++++++--------------------- 1 file changed, 112 insertions(+), 106 deletions(-) diff --git a/boxtree/traversal.py b/boxtree/traversal.py index 47c2d5c1..faa514dc 100644 --- a/boxtree/traversal.py +++ b/boxtree/traversal.py @@ -43,7 +43,7 @@ from pyopencl.elementwise import ElementwiseTemplate, ElementwiseKernel from arraycontext import Array -from pytools import ProcessLogger, log_process, memoize_method +from pytools import ProcessLogger, log_process, memoize_method, memoize_in from pytools.obj_array import make_obj_array from mako.template import Template @@ -1180,114 +1180,124 @@ class _IndexStyle(enum.IntEnum): class _ListMerger: - """Utility class for combining box lists optionally changing indexing style.""" - def __init__(self, array_context: PyOpenCLArrayContext, box_id_dtype): self._setup_actx = array_context self.box_id_dtype = box_id_dtype - @property - def context(self): - return self._setup_actx.queue.context + def __call__(self, actx, input_starts, input_lists, input_index_style, + output_index_style, target_boxes, target_or_target_parent_boxes, + nboxes, debug=False, wait_for=None): + from warnings import warn + warn(f"'{type(self).__name__}' is deprecated and will be removed in 2023. " + "Use 'merge_lists' instead.", + DeprecationWarning, stacklevel=2) + + return merge_lists( + actx, input_starts, input_lists, + input_index_style, output_index_style, + target_boxes, target_or_target_parent_boxes, nboxes, self.box_id_dtype, + debug=debug) + + +def merge_lists( + actx: PyOpenCLArrayContext, input_starts, input_lists, input_index_style, + output_index_style, target_boxes, target_or_target_parent_boxes, + nboxes, box_id_dtype, debug: bool = False): + """Utility class for combining box lists optionally changing indexing style. + + :arg input_starts: Starts arrays of input + :arg input_lists: Lists arrays of input + :arg input_index_style: A :class:`_IndexStyle` + :arg output_index_style: A :class:`_IndexStyle` + """ + # {{{ + + if ( + output_index_style == _IndexStyle.TARGET_OR_TARGET_PARENT_BOXES + and input_index_style == _IndexStyle.TARGET_BOXES): + raise ValueError( + "unsupported: merging a list indexed by target boxes " + "into a list indexed by target or target parent boxes") + + ntarget_boxes = len(target_boxes) + ntarget_or_ntarget_parent_boxes = len(target_or_target_parent_boxes) + + noutput_boxes = (ntarget_boxes + if output_index_style == _IndexStyle.TARGET_BOXES + else ntarget_or_ntarget_parent_boxes) + + if ( + input_index_style == _IndexStyle.TARGET_OR_TARGET_PARENT_BOXES + and output_index_style == _IndexStyle.TARGET_BOXES): + from boxtree.tools import reverse_index_array + target_or_target_parent_boxes_from_all_boxes = reverse_index_array( + actx, target_or_target_parent_boxes, target_size=nboxes) + target_or_target_parent_boxes_from_target_boxes = ( + target_or_target_parent_boxes_from_all_boxes[target_boxes]) + + output_to_input_box = target_or_target_parent_boxes_from_target_boxes + else: + output_to_input_box = actx.from_numpy( + np.arange(noutput_boxes, dtype=box_id_dtype)) + + new_counts = actx.empty(noutput_boxes + 1, box_id_dtype) + assert len(input_starts) == len(input_lists) + + nlists = len(input_starts) + assert nlists >= 1 - @memoize_method - def get_list_merger_kernel(self, nlists, write_counts): - """ - :arg nlists: Number of input lists - :arg write_counts: A :class:`bool`, indicating whether to generate a - kernel that produces box counts or box lists - """ - assert nlists >= 1 + # }}} + # {{{ merge lists + + @memoize_in(actx, (merge_lists, box_id_dtype, nlists)) + def get_list_merger_kernel(with_write_counts): return LIST_MERGER_TEMPLATE.build( - self.context, + actx.context, type_aliases=( - ("box_id_t", self.box_id_dtype), + ("box_id_t", box_id_dtype), ), var_values=( ("nlists", nlists), - ("write_counts", write_counts), + ("write_counts", with_write_counts), )) - def __call__(self, actx, input_starts, input_lists, input_index_style, - output_index_style, target_boxes, target_or_target_parent_boxes, - nboxes, debug=False, wait_for=None): - """ - :arg input_starts: Starts arrays of input - :arg input_lists: Lists arrays of input - :arg input_index_style: A :class:`_IndexStyle` - :arg output_index_style: A :class:`_IndexStyle` - :returns: A pair *results_dict, event*, where *results_dict* - contains entries *starts* and *lists* - """ - if wait_for is None: - wait_for = [] - - if ( - output_index_style == _IndexStyle.TARGET_OR_TARGET_PARENT_BOXES - and input_index_style == _IndexStyle.TARGET_BOXES): - raise ValueError( - "unsupported: merging a list indexed by target boxes " - "into a list indexed by target or target parent boxes") - - ntarget_boxes = len(target_boxes) - ntarget_or_ntarget_parent_boxes = len(target_or_target_parent_boxes) - - noutput_boxes = (ntarget_boxes - if output_index_style == _IndexStyle.TARGET_BOXES - else ntarget_or_ntarget_parent_boxes) - - if ( - input_index_style == _IndexStyle.TARGET_OR_TARGET_PARENT_BOXES - and output_index_style == _IndexStyle.TARGET_BOXES): - from boxtree.tools import reverse_index_array - target_or_target_parent_boxes_from_all_boxes = reverse_index_array( - actx, target_or_target_parent_boxes, target_size=nboxes) - target_or_target_parent_boxes_from_target_boxes = ( - target_or_target_parent_boxes_from_all_boxes[target_boxes] - ) - - output_to_input_box = target_or_target_parent_boxes_from_target_boxes - else: - output_to_input_box = actx.from_numpy( - np.arange(noutput_boxes, dtype=self.box_id_dtype) - ) - - new_counts = actx.empty(noutput_boxes + 1, self.box_id_dtype) - - assert len(input_starts) == len(input_lists) - nlists = len(input_starts) - - evt = self.get_list_merger_kernel(nlists, True)(*( - # input: - (output_to_input_box,) - + input_starts - # output: - + (new_counts,)), - range=slice(noutput_boxes), - queue=actx.queue, - wait_for=wait_for) - - import pyopencl.array as cl_array - new_starts = cl_array.cumsum(new_counts) - del new_counts + evt = get_list_merger_kernel(True)(*( + # input: + (output_to_input_box,) + + input_starts + # output: + + (new_counts,)), + range=slice(noutput_boxes), + queue=actx.queue, + ) + new_counts.add_event(evt) + + import pyopencl.array as cl_array + new_starts = cl_array.cumsum(new_counts) + del new_counts + + new_lists = actx.empty(int(actx.to_numpy(new_starts[-1])), box_id_dtype) + new_lists.fill(999999999) + + evt = get_list_merger_kernel(False)(*( + # input: + (output_to_input_box,) + + input_starts + + input_lists + + (new_starts,) + # output: + + (new_lists,)), + range=slice(noutput_boxes), + queue=actx.queue, + ) + new_starts.add_event(evt) + new_lists.add_event(evt) - new_lists = actx.empty(int(actx.to_numpy(new_starts[-1])), self.box_id_dtype) - new_lists.fill(999999999) + # }}} - evt = self.get_list_merger_kernel(nlists, False)(*( - # input: - (output_to_input_box,) - + input_starts - + input_lists - + (new_starts,) - # output: - + (new_lists,)), - range=slice(noutput_boxes), - queue=actx.queue, - wait_for=[evt]) + return {"starts": new_starts, "lists": new_lists} - return dict(starts=new_starts, lists=new_lists), evt # }}} @@ -1592,7 +1602,9 @@ def ntarget_or_target_parent_boxes(self): # {{{ "close" list merging -> "unified list 1" - def merge_close_lists(self, actx, debug=False): + def merge_close_lists(self, + actx: PyOpenCLArrayContext, + debug: bool = False) -> "FMMTraversalInfo": """Return a new :class:`FMMTraversalInfo` instance with the contents of :attr:`from_sep_close_smaller_starts` and :attr:`from_sep_close_bigger_starts` merged into @@ -1600,10 +1612,8 @@ def merge_close_lists(self, actx, debug=False): *None*. """ - list_merger = _ListMerger(actx, self.tree.box_id_dtype) - - result, evt = ( - list_merger( + result = ( + merge_lists( actx, # starts (self.neighbor_source_boxes_starts, @@ -1621,11 +1631,9 @@ def merge_close_lists(self, actx, debug=False): self.target_boxes, self.target_or_target_parent_boxes, self.tree.nboxes, + self.tree.box_id_dtype, debug)) - import pyopencl as cl - cl.wait_for_events([evt]) - from dataclasses import replace return replace(self, neighbor_source_boxes_starts=actx.freeze(result["starts"]), @@ -2159,8 +2167,7 @@ def extract_level_start_box_nrs(box_list, wait_for): from_sep_close_bigger_starts_raw = result["from_sep_close_bigger"].starts from_sep_close_bigger_lists_raw = result["from_sep_close_bigger"].lists - list_merger = _ListMerger(actx, tree.box_id_dtype) - result, evt = list_merger( + result = merge_lists( actx, # starts (from_sep_close_bigger_starts_raw,), @@ -2174,10 +2181,9 @@ def extract_level_start_box_nrs(box_list, wait_for): target_boxes, target_or_target_parent_boxes, tree.nboxes, + tree.box_id_dtype, debug, - wait_for=wait_for) - - wait_for = [evt] + ) del from_sep_close_bigger_starts_raw del from_sep_close_bigger_lists_raw