diff --git a/boxtree/constant_one.py b/boxtree/constant_one.py index 4e61f736..958989a1 100644 --- a/boxtree/constant_one.py +++ b/boxtree/constant_one.py @@ -26,6 +26,8 @@ """ import numpy as np + +from boxtree.array_context import PyOpenCLArrayContext from boxtree.fmm import TreeIndependentDataForWrangler, ExpansionWranglerInterface from boxtree.timing import DummyTimingFuture @@ -83,7 +85,9 @@ def local_expansions_view(self, local_exps, level): def timing_future(ops): return DummyTimingFuture.from_op_count(ops) - def form_multipoles(self, level_start_source_box_nrs, source_boxes, + def form_multipoles(self, actx: PyOpenCLArrayContext, + level_start_source_box_nrs, + source_boxes, src_weight_vecs): src_weights, = src_weight_vecs mpoles = self.multipole_expansion_zeros() @@ -96,8 +100,10 @@ def form_multipoles(self, level_start_source_box_nrs, source_boxes, return mpoles, self.timing_future(ops) - def coarsen_multipoles(self, level_start_source_parent_box_nrs, - source_parent_boxes, mpoles): + def coarsen_multipoles(self, actx: PyOpenCLArrayContext, + level_start_source_parent_box_nrs, + source_parent_boxes, + mpoles): tree = self.tree ops = 0 @@ -119,7 +125,8 @@ def coarsen_multipoles(self, level_start_source_parent_box_nrs, return mpoles, self.timing_future(ops) - def eval_direct(self, target_boxes, neighbor_sources_starts, + def eval_direct(self, actx: PyOpenCLArrayContext, + target_boxes, neighbor_sources_starts, neighbor_sources_lists, src_weight_vecs): src_weights, = src_weight_vecs pot = self.output_zeros() @@ -144,6 +151,7 @@ def eval_direct(self, target_boxes, neighbor_sources_starts, return pot, self.timing_future(ops) def multipole_to_local(self, + actx: PyOpenCLArrayContext, level_start_target_or_target_parent_box_nrs, target_or_target_parent_boxes, starts, lists, mpole_exps): @@ -164,7 +172,9 @@ def multipole_to_local(self, return local_exps, self.timing_future(ops) def eval_multipoles(self, - target_boxes_by_source_level, from_sep_smaller_nonsiblings_by_level, + actx: PyOpenCLArrayContext, + target_boxes_by_source_level, + from_sep_smaller_nonsiblings_by_level, mpole_exps): pot = self.output_zeros() ops = 0 @@ -186,8 +196,10 @@ def eval_multipoles(self, return pot, self.timing_future(ops) def form_locals(self, + actx: PyOpenCLArrayContext, level_start_target_or_target_parent_box_nrs, - target_or_target_parent_boxes, starts, lists, src_weight_vecs): + target_or_target_parent_boxes, + starts, lists, src_weight_vecs): src_weights, = src_weight_vecs local_exps = self.local_expansion_zeros() ops = 0 @@ -209,7 +221,9 @@ def form_locals(self, return local_exps, self.timing_future(ops) - def refine_locals(self, level_start_target_or_target_parent_box_nrs, + def refine_locals(self, + actx: PyOpenCLArrayContext, + level_start_target_or_target_parent_box_nrs, target_or_target_parent_boxes, local_exps): ops = 0 @@ -222,7 +236,10 @@ def refine_locals(self, level_start_target_or_target_parent_box_nrs, return local_exps, self.timing_future(ops) - def eval_locals(self, level_start_target_box_nrs, target_boxes, local_exps): + def eval_locals(self, + actx: PyOpenCLArrayContext, + level_start_target_box_nrs, + target_boxes, local_exps): pot = self.output_zeros() ops = 0 @@ -233,7 +250,7 @@ def eval_locals(self, level_start_target_box_nrs, target_boxes, local_exps): return pot, self.timing_future(ops) - def finalize_potentials(self, potentials, template_ary): + def finalize_potentials(self, actx: PyOpenCLArrayContext, potentials): return potentials # }}} diff --git a/boxtree/distributed/__init__.py b/boxtree/distributed/__init__.py index 5ee39987..7ebd2d13 100644 --- a/boxtree/distributed/__init__.py +++ b/boxtree/distributed/__init__.py @@ -289,11 +289,14 @@ def __init__(self, array_context: PyOpenCLArrayContext, global_tree, array_context, global_tree, traversal_builder, wrangler_factory, calibration_params, comm) - def drive_dfmm(self, source_weights, timing_data=None): - """Calculate potentials at target points. - """ + def drive_dfmm(self, + actx: PyOpenCLArrayContext, + source_weights, + timing_data=None): + """Calculate potentials at target points.""" from boxtree.fmm import drive_fmm return drive_fmm( + actx, self.wrangler, source_weights, timing_data=timing_data, global_src_idx_all_ranks=self.src_idx_all_ranks, diff --git a/boxtree/fmm.py b/boxtree/fmm.py index d26b0939..2b8992cf 100644 --- a/boxtree/fmm.py +++ b/boxtree/fmm.py @@ -33,6 +33,7 @@ from boxtree.tree import Tree from boxtree.traversal import FMMTraversalInfo +from boxtree.array_context import PyOpenCLArrayContext import logging logger = logging.getLogger(__name__) @@ -155,6 +156,7 @@ def local_expansions_view(self, local_exps, level): @abstractmethod def form_multipoles(self, + actx: PyOpenCLArrayContext, level_start_source_box_nrs, source_boxes, src_weight_vecs): """Return an expansions array @@ -167,6 +169,7 @@ def form_multipoles(self, @abstractmethod def coarsen_multipoles(self, + actx: PyOpenCLArrayContext, level_start_source_parent_box_nrs, source_parent_boxes, mpoles): """For each box in *source_parent_boxes*, @@ -179,6 +182,7 @@ def coarsen_multipoles(self, @abstractmethod def eval_direct(self, + actx: PyOpenCLArrayContext, target_boxes, neighbor_sources_starts, neighbor_sources_lists, src_weight_vecs): """For each box in *target_boxes*, evaluate the influence of the @@ -191,6 +195,7 @@ def eval_direct(self, @abstractmethod def multipole_to_local(self, + actx: PyOpenCLArrayContext, level_start_target_or_target_parent_box_nrs, target_or_target_parent_boxes, starts, lists, mpole_exps): @@ -205,6 +210,7 @@ def multipole_to_local(self, @abstractmethod def eval_multipoles(self, + actx: PyOpenCLArrayContext, target_boxes_by_source_level, from_sep_smaller_by_level, mpole_exps): """For a level *i*, each box in *target_boxes_by_source_level[i]*, evaluate the multipole expansion in *mpole_exps* in the nearby boxes given in @@ -218,6 +224,7 @@ def eval_multipoles(self, @abstractmethod def form_locals(self, + actx: PyOpenCLArrayContext, level_start_target_or_target_parent_box_nrs, target_or_target_parent_boxes, starts, lists, src_weight_vecs): """For each box in *target_or_target_parent_boxes*, form local @@ -232,6 +239,7 @@ def form_locals(self, @abstractmethod def refine_locals(self, + actx: PyOpenCLArrayContext, level_start_target_or_target_parent_box_nrs, target_or_target_parent_boxes, local_exps): """For each box in *child_boxes*, @@ -243,6 +251,7 @@ def refine_locals(self, @abstractmethod def eval_locals(self, + actx: PyOpenCLArrayContext, level_start_target_box_nrs, target_boxes, local_exps): """For each box in *target_boxes*, evaluate the local expansion in *local_exps* and return a new potential array. @@ -254,7 +263,7 @@ def eval_locals(self, # }}} @abstractmethod - def finalize_potentials(self, potentials, template_ary): + def finalize_potentials(self, actx: PyOpenCLArrayContext, potentials): """ Postprocess the reordered potentials. This is where global scaling factors could be applied. This is distinct from :meth:`reorder_potentials` @@ -324,9 +333,12 @@ def communicate_mpoles(self, mpole_exps, return_stats=False): # }}} -def drive_fmm(wrangler: ExpansionWranglerInterface, src_weight_vecs, +def drive_fmm(actx: PyOpenCLArrayContext, + wrangler: ExpansionWranglerInterface, + src_weight_vecs, *, timing_data=None, - global_src_idx_all_ranks=None, global_tgt_idx_all_ranks=None): + global_src_idx_all_ranks=None, + global_tgt_idx_all_ranks=None): """Top-level driver routine for a fast multipole calculation. In part, this is intended as a template for custom FMMs, in the sense that @@ -382,6 +394,7 @@ def drive_fmm(wrangler: ExpansionWranglerInterface, src_weight_vecs, # {{{ "Step 2.1:" Construct local multipoles mpole_exps, timing_future = wrangler.form_multipoles( + actx, traversal.level_start_source_box_nrs, traversal.source_boxes, src_weight_vecs) @@ -393,6 +406,7 @@ def drive_fmm(wrangler: ExpansionWranglerInterface, src_weight_vecs, # {{{ "Step 2.2:" Propagate multipoles upward mpole_exps, timing_future = wrangler.coarsen_multipoles( + actx, traversal.level_start_source_parent_box_nrs, traversal.source_parent_boxes, mpole_exps) @@ -408,6 +422,7 @@ def drive_fmm(wrangler: ExpansionWranglerInterface, src_weight_vecs, # {{{ "Stage 3:" Direct evaluation from neighbor source boxes ("list 1") potentials, timing_future = wrangler.eval_direct( + actx, traversal.target_boxes, traversal.neighbor_source_boxes_starts, traversal.neighbor_source_boxes_lists, @@ -422,6 +437,7 @@ def drive_fmm(wrangler: ExpansionWranglerInterface, src_weight_vecs, # {{{ "Stage 4:" translate separated siblings' ("list 2") mpoles to local local_exps, timing_future = wrangler.multipole_to_local( + actx, traversal.level_start_target_or_target_parent_box_nrs, traversal.target_or_target_parent_boxes, traversal.from_sep_siblings_starts, @@ -440,6 +456,7 @@ def drive_fmm(wrangler: ExpansionWranglerInterface, src_weight_vecs, # contribution *out* of the downward-propagating local expansions) mpole_result, timing_future = wrangler.eval_multipoles( + actx, traversal.target_boxes_sep_smaller_by_source_level, traversal.from_sep_smaller_by_level, mpole_exps) @@ -455,6 +472,7 @@ def drive_fmm(wrangler: ExpansionWranglerInterface, src_weight_vecs, "('list 3 close')") direct_result, timing_future = wrangler.eval_direct( + actx, traversal.target_boxes, traversal.from_sep_close_smaller_starts, traversal.from_sep_close_smaller_lists, @@ -469,6 +487,7 @@ def drive_fmm(wrangler: ExpansionWranglerInterface, src_weight_vecs, # {{{ "Stage 6:" form locals for separated bigger source boxes ("list 4") local_result, timing_future = wrangler.form_locals( + actx, traversal.level_start_target_or_target_parent_box_nrs, traversal.target_or_target_parent_boxes, traversal.from_sep_bigger_starts, @@ -481,6 +500,7 @@ def drive_fmm(wrangler: ExpansionWranglerInterface, src_weight_vecs, if traversal.from_sep_close_bigger_starts is not None: direct_result, timing_future = wrangler.eval_direct( + actx, traversal.target_boxes, traversal.from_sep_close_bigger_starts, traversal.from_sep_close_bigger_lists, @@ -495,6 +515,7 @@ def drive_fmm(wrangler: ExpansionWranglerInterface, src_weight_vecs, # {{{ "Stage 7:" propagate local_exps downward local_exps, timing_future = wrangler.refine_locals( + actx, traversal.level_start_target_or_target_parent_box_nrs, traversal.target_or_target_parent_boxes, local_exps) @@ -506,6 +527,7 @@ def drive_fmm(wrangler: ExpansionWranglerInterface, src_weight_vecs, # {{{ "Stage 8:" evaluate locals local_result, timing_future = wrangler.eval_locals( + actx, traversal.level_start_target_box_nrs, traversal.target_boxes, local_exps) @@ -521,7 +543,7 @@ def drive_fmm(wrangler: ExpansionWranglerInterface, src_weight_vecs, result = wrangler.reorder_potentials(potentials) - result = wrangler.finalize_potentials(result, template_ary=src_weight_vecs[0]) + result = wrangler.finalize_potentials(actx, result) fmm_proc.done() diff --git a/boxtree/pyfmmlib_integration.py b/boxtree/pyfmmlib_integration.py index 4cb62171..95ddcbe8 100644 --- a/boxtree/pyfmmlib_integration.py +++ b/boxtree/pyfmmlib_integration.py @@ -684,7 +684,9 @@ def reorder_potentials(self, potentials): @log_process(logger) @return_timing_data - def form_multipoles(self, level_start_source_box_nrs, source_boxes, + def form_multipoles(self, actx: PyOpenCLArrayContext, + level_start_source_box_nrs, + source_boxes, src_weight_vecs): src_weights, = src_weight_vecs formmp = self.tree_indep.get_routine( @@ -727,8 +729,10 @@ def form_multipoles(self, level_start_source_box_nrs, source_boxes, @log_process(logger) @return_timing_data - def coarsen_multipoles(self, level_start_source_parent_box_nrs, - source_parent_boxes, mpoles): + def coarsen_multipoles(self, actx: PyOpenCLArrayContext, + level_start_source_parent_box_nrs, + source_parent_boxes, + mpoles): tree = self.tree mpmp = self.tree_indep.get_translation_routine(self, "%ddmpmp") @@ -783,8 +787,11 @@ def coarsen_multipoles(self, level_start_source_parent_box_nrs, @log_process(logger) @return_timing_data - def eval_direct(self, target_boxes, neighbor_sources_starts, - neighbor_sources_lists, src_weight_vecs): + def eval_direct(self, actx: PyOpenCLArrayContext, + target_boxes, + neighbor_sources_starts, + neighbor_sources_lists, + src_weight_vecs): src_weights, = src_weight_vecs output = self.output_zeros() @@ -827,7 +834,7 @@ def eval_direct(self, target_boxes, neighbor_sources_starts, @log_process(logger) @return_timing_data - def multipole_to_local(self, + def multipole_to_local(self, actx: PyOpenCLArrayContext, level_start_target_or_target_parent_box_nrs, target_or_target_parent_boxes, starts, lists, mpole_exps): @@ -942,8 +949,9 @@ def multipole_to_local(self, @log_process(logger) @return_timing_data - def eval_multipoles(self, - target_boxes_by_source_level, sep_smaller_nonsiblings_by_level, + def eval_multipoles(self, actx: PyOpenCLArrayContext, + target_boxes_by_source_level, + sep_smaller_nonsiblings_by_level, mpole_exps): output = self.output_zeros() @@ -985,9 +993,10 @@ def eval_multipoles(self, @log_process(logger) @return_timing_data - def form_locals(self, + def form_locals(self, actx: PyOpenCLArrayContext, level_start_target_or_target_parent_box_nrs, - target_or_target_parent_boxes, starts, lists, src_weight_vecs): + target_or_target_parent_boxes, + starts, lists, src_weight_vecs): src_weights, = src_weight_vecs local_exps = self.local_expansion_zeros() @@ -1065,8 +1074,10 @@ def form_locals(self, @log_process(logger) @return_timing_data - def refine_locals(self, level_start_target_or_target_parent_box_nrs, - target_or_target_parent_boxes, local_exps): + def refine_locals(self, actx: PyOpenCLArrayContext, + level_start_target_or_target_parent_box_nrs, + target_or_target_parent_boxes, + local_exps): locloc = self.tree_indep.get_translation_routine(self, "%ddlocloc") @@ -1112,7 +1123,10 @@ def refine_locals(self, level_start_target_or_target_parent_box_nrs, @log_process(logger) @return_timing_data - def eval_locals(self, level_start_target_box_nrs, target_boxes, local_exps): + def eval_locals(self, actx: PyOpenCLArrayContext, + level_start_target_box_nrs, + target_boxes, + local_exps): output = self.output_zeros() taeval = self.tree_indep.get_expn_eval_routine("ta") @@ -1147,7 +1161,7 @@ def eval_locals(self, level_start_target_box_nrs, target_boxes, local_exps): return output @log_process(logger) - def finalize_potentials(self, potential, template_ary): + def finalize_potentials(self, actx: PyOpenCLArrayContext, potential): if self.tree_indep.eqn_letter == "l" and self.dim == 2: scale_factor = -1/(2*np.pi) elif self.tree_indep.eqn_letter == "h" and self.dim == 2: diff --git a/boxtree/traversal.py b/boxtree/traversal.py index 7e777150..47c2d5c1 100644 --- a/boxtree/traversal.py +++ b/boxtree/traversal.py @@ -1701,7 +1701,8 @@ def get_kernel_info(self, dimensions, particle_id_dtype, box_id_dtype, sources_are_targets, sources_have_extent, targets_have_extent, extent_norm, source_boxes_has_mask, - source_parent_boxes_has_mask): + source_parent_boxes_has_mask, + debug=False): # {{{ process from_sep_smaller_crit @@ -1743,8 +1744,6 @@ def get_kernel_info(self, dimensions, particle_id_dtype, box_id_dtype, # }}} - debug = False - from pyopencl.tools import dtype_to_ctype from boxtree.tree import box_flags_enum diff --git a/examples/cost_model.py b/examples/cost_model.py index 94764afc..556508b9 100644 --- a/examples/cost_model.py +++ b/examples/cost_model.py @@ -73,7 +73,7 @@ def fmm_level_to_order(tree, ilevel): timing_data = {} from boxtree.fmm import drive_fmm src_weights = np.random.rand(tree.nsources).astype(tree.coord_dtype) - drive_fmm(wrangler, (src_weights,), timing_data=timing_data) + drive_fmm(actx, wrangler, (src_weights,), timing_data=timing_data) timing_results.append(timing_data) diff --git a/test/test_cost_model.py b/test/test_cost_model.py index e6c19229..8850c179 100644 --- a/test/test_cost_model.py +++ b/test/test_cost_model.py @@ -434,7 +434,7 @@ def fmm_level_to_order(tree, ilevel): timing_data = {} from boxtree.fmm import drive_fmm src_weights = np.random.rand(tree.nsources).astype(tree.coord_dtype) - drive_fmm(wrangler, (src_weights,), timing_data=timing_data) + drive_fmm(actx, wrangler, (src_weights,), timing_data=timing_data) timing_results.append(timing_data) @@ -560,7 +560,7 @@ def test_cost_model_op_counts_agree_with_constantone_wrangler( timing_data = {} from boxtree.fmm import drive_fmm src_weights = np.random.rand(tree.nsources).astype(tree.coord_dtype) - drive_fmm(wrangler, (src_weights,), timing_data=timing_data) + drive_fmm(actx, wrangler, (src_weights,), timing_data=timing_data) cost_model = FMMCostModel( translation_cost_model_factory=OpCountingTranslationCostModel diff --git a/test/test_distributed.py b/test/test_distributed.py index 56997011..8357cd5b 100644 --- a/test/test_distributed.py +++ b/test/test_distributed.py @@ -114,7 +114,7 @@ def fmm_level_to_order(tree, level): # Compute FMM with one MPI rank from boxtree.fmm import drive_fmm - pot_fmm = drive_fmm(wrangler, [sources_weights]) * 2 * np.pi + pot_fmm = drive_fmm(actx, wrangler, [sources_weights]) * 2 * np.pi # Compute FMM using the distributed implementation @@ -133,7 +133,7 @@ def wrangler_factory(local_traversal, global_traversal): timing_data = {} pot_dfmm = distribued_fmm_info.drive_dfmm( - [sources_weights], timing_data=timing_data) + actx, [sources_weights], timing_data=timing_data) assert timing_data # Uncomment the following section to print the time taken of each stage @@ -265,7 +265,7 @@ def wrangler_factory(local_traversal, global_traversal): distributed_fmm_info = DistributedFMMRunner( actx, tree, tg, wrangler_factory, comm=MPI.COMM_WORLD) - pot_dfmm = distributed_fmm_info.drive_dfmm([sources_weights]) + pot_dfmm = distributed_fmm_info.drive_dfmm(actx, [sources_weights]) if rank == 0: assert (np.all(pot_dfmm == nsources)) diff --git a/test/test_fmm.py b/test/test_fmm.py index 0fb6f31a..cd0aa826 100644 --- a/test/test_fmm.py +++ b/test/test_fmm.py @@ -48,7 +48,8 @@ # {{{ ref fmmlib pot computation -def get_fmmlib_ref_pot(wrangler, weights, sources_host, targets_host, +def get_fmmlib_ref_pot( + actx, wrangler, weights, sources_host, targets_host, helmholtz_k, dipole_vec=None): dims = sources_host.shape[0] eqn_letter = "h" if helmholtz_k else "l" @@ -84,10 +85,10 @@ def get_fmmlib_ref_pot(wrangler, weights, sources_host, targets_host, kwargs["zk"] = helmholtz_k return wrangler.finalize_potentials( + actx, fmmlib_routine( sources=sources_host, targets=targets_host, - **kwargs)[0], - template_ary=weights) + **kwargs)[0]) # }}} @@ -274,7 +275,7 @@ def test_fmm_completeness(actx_factory, dims, nsources_req, ntargets_req, == weights) from boxtree.fmm import drive_fmm - pot = drive_fmm(wrangler, (weights,)) + pot = drive_fmm(actx, wrangler, (weights,)) if filter_kind: pot = pot[actx.to_numpy(flags) > 0] @@ -292,7 +293,7 @@ def test_fmm_completeness(actx_factory, dims, nsources_req, ntargets_req, for i in range(nsources): unit_vec = np.zeros(nsources, dtype=dtype) unit_vec[i] = 1 - mat[:, i] = drive_fmm(host_trav, wrangler, (unit_vec,)) + mat[:, i] = drive_fmm(actx, wrangler, (unit_vec,)) pb.progress() pb.finished() @@ -406,8 +407,8 @@ def test_pyfmmlib_fmm(actx_factory, dims, use_dipoles, helmholtz_k): p_normal(actx, ntargets, dims, dtype, seed=18) + np.array([2, 0, 0])[:dims]) - sources_host = particle_array_to_host(actx, sources) - targets_host = particle_array_to_host(actx, targets) + sources_host = np.stack(actx.to_numpy(sources)) + targets_host = np.stack(actx.to_numpy(targets)) from boxtree import TreeBuilder tb = TreeBuilder(actx) @@ -461,7 +462,7 @@ def fmm_level_to_order(tree, lev): from boxtree.fmm import drive_fmm timing_data = {} - pot = drive_fmm(wrangler, (weights,), timing_data=timing_data) + pot = drive_fmm(actx, wrangler, (weights,), timing_data=timing_data) print(timing_data) assert timing_data @@ -469,8 +470,8 @@ def fmm_level_to_order(tree, lev): logger.info("computing direct (reference) result") - ref_pot = get_fmmlib_ref_pot(wrangler, weights, sources_host.T, - targets_host.T, helmholtz_k, dipole_vec) + ref_pot = get_fmmlib_ref_pot(actx, wrangler, weights, sources_host, + targets_host, helmholtz_k, dipole_vec) rel_err = la.norm(pot - ref_pot, np.inf) / la.norm(ref_pot, np.inf) logger.info("relative l2 error vs fmmlib direct: %g", rel_err) @@ -504,15 +505,17 @@ def fmm_level_to_order(tree, lev): if use_dipoles: knl = DirectionalSourceDerivative(knl) - sumpy_extra_kwargs["src_derivative_dir"] = dipole_vec + sumpy_extra_kwargs["src_derivative_dir"] = actx.from_numpy(dipole_vec) - p2p = P2P(actx.context, - [knl], - exclude_self=False) + p2p = P2P(target_kernels=[knl], exclude_self=False) - evt, (sumpy_ref_pot,) = p2p( - actx.queue, targets, sources, (weights,), - out_host=True, **sumpy_extra_kwargs) + result = p2p( + actx, + targets, + sources, + (actx.from_numpy(weights),), + **sumpy_extra_kwargs) + sumpy_ref_pot = actx.to_numpy(result["result_s0"]) sumpy_rel_err = ( la.norm(pot - sumpy_ref_pot, np.inf) @@ -584,14 +587,14 @@ def fmm_level_to_order(tree, lev): rotation_data=FMMLibRotationData(actx, trav)) from boxtree.fmm import drive_fmm - pot = drive_fmm(wrangler, (weights,)) + pot = drive_fmm(actx, wrangler, (weights,)) assert not np.isnan(pot).any() # {{{ ref fmmlib computation logger.info("computing direct (reference) result") - ref_pot = get_fmmlib_ref_pot(wrangler, weights, sources, targets, + ref_pot = get_fmmlib_ref_pot(actx, wrangler, weights, sources, targets, helmholtz_k) rel_err = la.norm(pot - ref_pot, np.inf) / la.norm(ref_pot, np.inf) @@ -657,7 +660,7 @@ def test_interaction_list_particle_count_thresholding(actx_factory, enable_exten tree_indep = ConstantOneTreeIndependentDataForWrangler() wrangler = ConstantOneExpansionWrangler(tree_indep, host_trav) - pot = drive_fmm(wrangler, (weights,)) + pot = drive_fmm(actx, wrangler, (weights,)) assert np.all(pot == weights_sum) @@ -711,7 +714,7 @@ def test_fmm_float32(actx_factory, enable_extents): tree_indep = ConstantOneTreeIndependentDataForWrangler() wrangler = ConstantOneExpansionWrangler(tree_indep, host_trav) - pot = drive_fmm(wrangler, (weights,)) + pot = drive_fmm(actx, wrangler, (weights,)) assert np.all(pot == weights_sum) @@ -784,11 +787,11 @@ def fmm_level_to_order(tree, lev): baseline_timing_data = {} baseline_pot = drive_fmm( - baseline_wrangler, (weights,), timing_data=baseline_timing_data) + actx, baseline_wrangler, (weights,), timing_data=baseline_timing_data) optimized_timing_data = {} optimized_pot = drive_fmm( - optimized_wrangler, (weights,), timing_data=optimized_timing_data) + actx, optimized_wrangler, (weights,), timing_data=optimized_timing_data) baseline_time = baseline_timing_data["multipole_to_local"]["process_elapsed"] if baseline_time is not None: