Skip to content

Commit

Permalink
update fmm interface for sumpy
Browse files Browse the repository at this point in the history
  • Loading branch information
alexfikl committed Sep 10, 2022
1 parent a0c9bcb commit d40443a
Show file tree
Hide file tree
Showing 9 changed files with 120 additions and 62 deletions.
35 changes: 26 additions & 9 deletions boxtree/constant_one.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
"""

import numpy as np

from boxtree.array_context import PyOpenCLArrayContext
from boxtree.fmm import TreeIndependentDataForWrangler, ExpansionWranglerInterface
from boxtree.timing import DummyTimingFuture

Expand Down Expand Up @@ -83,7 +85,9 @@ def local_expansions_view(self, local_exps, level):
def timing_future(ops):
return DummyTimingFuture.from_op_count(ops)

def form_multipoles(self, level_start_source_box_nrs, source_boxes,
def form_multipoles(self, actx: PyOpenCLArrayContext,
level_start_source_box_nrs,
source_boxes,
src_weight_vecs):
src_weights, = src_weight_vecs
mpoles = self.multipole_expansion_zeros()
Expand All @@ -96,8 +100,10 @@ def form_multipoles(self, level_start_source_box_nrs, source_boxes,

return mpoles, self.timing_future(ops)

def coarsen_multipoles(self, level_start_source_parent_box_nrs,
source_parent_boxes, mpoles):
def coarsen_multipoles(self, actx: PyOpenCLArrayContext,
level_start_source_parent_box_nrs,
source_parent_boxes,
mpoles):
tree = self.tree
ops = 0

Expand All @@ -119,7 +125,8 @@ def coarsen_multipoles(self, level_start_source_parent_box_nrs,

return mpoles, self.timing_future(ops)

def eval_direct(self, target_boxes, neighbor_sources_starts,
def eval_direct(self, actx: PyOpenCLArrayContext,
target_boxes, neighbor_sources_starts,
neighbor_sources_lists, src_weight_vecs):
src_weights, = src_weight_vecs
pot = self.output_zeros()
Expand All @@ -144,6 +151,7 @@ def eval_direct(self, target_boxes, neighbor_sources_starts,
return pot, self.timing_future(ops)

def multipole_to_local(self,
actx: PyOpenCLArrayContext,
level_start_target_or_target_parent_box_nrs,
target_or_target_parent_boxes,
starts, lists, mpole_exps):
Expand All @@ -164,7 +172,9 @@ def multipole_to_local(self,
return local_exps, self.timing_future(ops)

def eval_multipoles(self,
target_boxes_by_source_level, from_sep_smaller_nonsiblings_by_level,
actx: PyOpenCLArrayContext,
target_boxes_by_source_level,
from_sep_smaller_nonsiblings_by_level,
mpole_exps):
pot = self.output_zeros()
ops = 0
Expand All @@ -186,8 +196,10 @@ def eval_multipoles(self,
return pot, self.timing_future(ops)

def form_locals(self,
actx: PyOpenCLArrayContext,
level_start_target_or_target_parent_box_nrs,
target_or_target_parent_boxes, starts, lists, src_weight_vecs):
target_or_target_parent_boxes,
starts, lists, src_weight_vecs):
src_weights, = src_weight_vecs
local_exps = self.local_expansion_zeros()
ops = 0
Expand All @@ -209,7 +221,9 @@ def form_locals(self,

return local_exps, self.timing_future(ops)

def refine_locals(self, level_start_target_or_target_parent_box_nrs,
def refine_locals(self,
actx: PyOpenCLArrayContext,
level_start_target_or_target_parent_box_nrs,
target_or_target_parent_boxes, local_exps):
ops = 0

Expand All @@ -222,7 +236,10 @@ def refine_locals(self, level_start_target_or_target_parent_box_nrs,

return local_exps, self.timing_future(ops)

def eval_locals(self, level_start_target_box_nrs, target_boxes, local_exps):
def eval_locals(self,
actx: PyOpenCLArrayContext,
level_start_target_box_nrs,
target_boxes, local_exps):
pot = self.output_zeros()
ops = 0

Expand All @@ -233,7 +250,7 @@ def eval_locals(self, level_start_target_box_nrs, target_boxes, local_exps):

return pot, self.timing_future(ops)

def finalize_potentials(self, potentials, template_ary):
def finalize_potentials(self, actx: PyOpenCLArrayContext, potentials):
return potentials

# }}}
Expand Down
9 changes: 6 additions & 3 deletions boxtree/distributed/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,11 +289,14 @@ def __init__(self, array_context: PyOpenCLArrayContext, global_tree,
array_context, global_tree, traversal_builder, wrangler_factory,
calibration_params, comm)

def drive_dfmm(self, source_weights, timing_data=None):
"""Calculate potentials at target points.
"""
def drive_dfmm(self,
actx: PyOpenCLArrayContext,
source_weights,
timing_data=None):
"""Calculate potentials at target points."""
from boxtree.fmm import drive_fmm
return drive_fmm(
actx,
self.wrangler, source_weights,
timing_data=timing_data,
global_src_idx_all_ranks=self.src_idx_all_ranks,
Expand Down
30 changes: 26 additions & 4 deletions boxtree/fmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

from boxtree.tree import Tree
from boxtree.traversal import FMMTraversalInfo
from boxtree.array_context import PyOpenCLArrayContext

import logging
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -155,6 +156,7 @@ def local_expansions_view(self, local_exps, level):

@abstractmethod
def form_multipoles(self,
actx: PyOpenCLArrayContext,
level_start_source_box_nrs, source_boxes,
src_weight_vecs):
"""Return an expansions array
Expand All @@ -167,6 +169,7 @@ def form_multipoles(self,

@abstractmethod
def coarsen_multipoles(self,
actx: PyOpenCLArrayContext,
level_start_source_parent_box_nrs,
source_parent_boxes, mpoles):
"""For each box in *source_parent_boxes*,
Expand All @@ -179,6 +182,7 @@ def coarsen_multipoles(self,

@abstractmethod
def eval_direct(self,
actx: PyOpenCLArrayContext,
target_boxes, neighbor_sources_starts,
neighbor_sources_lists, src_weight_vecs):
"""For each box in *target_boxes*, evaluate the influence of the
Expand All @@ -191,6 +195,7 @@ def eval_direct(self,

@abstractmethod
def multipole_to_local(self,
actx: PyOpenCLArrayContext,
level_start_target_or_target_parent_box_nrs,
target_or_target_parent_boxes,
starts, lists, mpole_exps):
Expand All @@ -205,6 +210,7 @@ def multipole_to_local(self,

@abstractmethod
def eval_multipoles(self,
actx: PyOpenCLArrayContext,
target_boxes_by_source_level, from_sep_smaller_by_level, mpole_exps):
"""For a level *i*, each box in *target_boxes_by_source_level[i]*, evaluate
the multipole expansion in *mpole_exps* in the nearby boxes given in
Expand All @@ -218,6 +224,7 @@ def eval_multipoles(self,

@abstractmethod
def form_locals(self,
actx: PyOpenCLArrayContext,
level_start_target_or_target_parent_box_nrs,
target_or_target_parent_boxes, starts, lists, src_weight_vecs):
"""For each box in *target_or_target_parent_boxes*, form local
Expand All @@ -232,6 +239,7 @@ def form_locals(self,

@abstractmethod
def refine_locals(self,
actx: PyOpenCLArrayContext,
level_start_target_or_target_parent_box_nrs,
target_or_target_parent_boxes, local_exps):
"""For each box in *child_boxes*,
Expand All @@ -243,6 +251,7 @@ def refine_locals(self,

@abstractmethod
def eval_locals(self,
actx: PyOpenCLArrayContext,
level_start_target_box_nrs, target_boxes, local_exps):
"""For each box in *target_boxes*, evaluate the local expansion in
*local_exps* and return a new potential array.
Expand All @@ -254,7 +263,7 @@ def eval_locals(self,
# }}}

@abstractmethod
def finalize_potentials(self, potentials, template_ary):
def finalize_potentials(self, actx: PyOpenCLArrayContext, potentials):
"""
Postprocess the reordered potentials. This is where global scaling
factors could be applied. This is distinct from :meth:`reorder_potentials`
Expand Down Expand Up @@ -324,9 +333,12 @@ def communicate_mpoles(self, mpole_exps, return_stats=False):
# }}}


def drive_fmm(wrangler: ExpansionWranglerInterface, src_weight_vecs,
def drive_fmm(actx: PyOpenCLArrayContext,
wrangler: ExpansionWranglerInterface,
src_weight_vecs, *,
timing_data=None,
global_src_idx_all_ranks=None, global_tgt_idx_all_ranks=None):
global_src_idx_all_ranks=None,
global_tgt_idx_all_ranks=None):
"""Top-level driver routine for a fast multipole calculation.
In part, this is intended as a template for custom FMMs, in the sense that
Expand Down Expand Up @@ -382,6 +394,7 @@ def drive_fmm(wrangler: ExpansionWranglerInterface, src_weight_vecs,
# {{{ "Step 2.1:" Construct local multipoles

mpole_exps, timing_future = wrangler.form_multipoles(
actx,
traversal.level_start_source_box_nrs,
traversal.source_boxes,
src_weight_vecs)
Expand All @@ -393,6 +406,7 @@ def drive_fmm(wrangler: ExpansionWranglerInterface, src_weight_vecs,
# {{{ "Step 2.2:" Propagate multipoles upward

mpole_exps, timing_future = wrangler.coarsen_multipoles(
actx,
traversal.level_start_source_parent_box_nrs,
traversal.source_parent_boxes,
mpole_exps)
Expand All @@ -408,6 +422,7 @@ def drive_fmm(wrangler: ExpansionWranglerInterface, src_weight_vecs,
# {{{ "Stage 3:" Direct evaluation from neighbor source boxes ("list 1")

potentials, timing_future = wrangler.eval_direct(
actx,
traversal.target_boxes,
traversal.neighbor_source_boxes_starts,
traversal.neighbor_source_boxes_lists,
Expand All @@ -422,6 +437,7 @@ def drive_fmm(wrangler: ExpansionWranglerInterface, src_weight_vecs,
# {{{ "Stage 4:" translate separated siblings' ("list 2") mpoles to local

local_exps, timing_future = wrangler.multipole_to_local(
actx,
traversal.level_start_target_or_target_parent_box_nrs,
traversal.target_or_target_parent_boxes,
traversal.from_sep_siblings_starts,
Expand All @@ -440,6 +456,7 @@ def drive_fmm(wrangler: ExpansionWranglerInterface, src_weight_vecs,
# contribution *out* of the downward-propagating local expansions)

mpole_result, timing_future = wrangler.eval_multipoles(
actx,
traversal.target_boxes_sep_smaller_by_source_level,
traversal.from_sep_smaller_by_level,
mpole_exps)
Expand All @@ -455,6 +472,7 @@ def drive_fmm(wrangler: ExpansionWranglerInterface, src_weight_vecs,
"('list 3 close')")

direct_result, timing_future = wrangler.eval_direct(
actx,
traversal.target_boxes,
traversal.from_sep_close_smaller_starts,
traversal.from_sep_close_smaller_lists,
Expand All @@ -469,6 +487,7 @@ def drive_fmm(wrangler: ExpansionWranglerInterface, src_weight_vecs,
# {{{ "Stage 6:" form locals for separated bigger source boxes ("list 4")

local_result, timing_future = wrangler.form_locals(
actx,
traversal.level_start_target_or_target_parent_box_nrs,
traversal.target_or_target_parent_boxes,
traversal.from_sep_bigger_starts,
Expand All @@ -481,6 +500,7 @@ def drive_fmm(wrangler: ExpansionWranglerInterface, src_weight_vecs,

if traversal.from_sep_close_bigger_starts is not None:
direct_result, timing_future = wrangler.eval_direct(
actx,
traversal.target_boxes,
traversal.from_sep_close_bigger_starts,
traversal.from_sep_close_bigger_lists,
Expand All @@ -495,6 +515,7 @@ def drive_fmm(wrangler: ExpansionWranglerInterface, src_weight_vecs,
# {{{ "Stage 7:" propagate local_exps downward

local_exps, timing_future = wrangler.refine_locals(
actx,
traversal.level_start_target_or_target_parent_box_nrs,
traversal.target_or_target_parent_boxes,
local_exps)
Expand All @@ -506,6 +527,7 @@ def drive_fmm(wrangler: ExpansionWranglerInterface, src_weight_vecs,
# {{{ "Stage 8:" evaluate locals

local_result, timing_future = wrangler.eval_locals(
actx,
traversal.level_start_target_box_nrs,
traversal.target_boxes,
local_exps)
Expand All @@ -521,7 +543,7 @@ def drive_fmm(wrangler: ExpansionWranglerInterface, src_weight_vecs,

result = wrangler.reorder_potentials(potentials)

result = wrangler.finalize_potentials(result, template_ary=src_weight_vecs[0])
result = wrangler.finalize_potentials(actx, result)

fmm_proc.done()

Expand Down
Loading

0 comments on commit d40443a

Please sign in to comment.