Skip to content

Commit

Permalink
fix some memoized kernel keys
Browse files Browse the repository at this point in the history
  • Loading branch information
alexfikl committed Sep 18, 2022
1 parent deb0b1f commit 189b9fc
Show file tree
Hide file tree
Showing 9 changed files with 366 additions and 288 deletions.
312 changes: 171 additions & 141 deletions boxtree/area_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from pyopencl.elementwise import ElementwiseTemplate

from arraycontext import Array
from pytools import memoize_in, ProcessLogger
from pytools import ProcessLogger, memoize_on_first_arg
from mako.template import Template

from boxtree.tree import Tree
Expand Down Expand Up @@ -593,6 +593,78 @@ class AreaQueryResult:
leaves_near_ball_lists: Array


@memoize_on_first_arg
def get_area_query_kernel(
actx: PyOpenCLArrayContext,
dimensions: int,
coord_dtype: "np.dtype",
box_id_dtype: "np.dtype",
ball_id_dtype: "np.dtype",
peer_list_idx_dtype: "np.dtype",
max_levels: int,
root_extent_stretch_factor: float):
from pyopencl.tools import dtype_to_ctype

from boxtree import box_flags_enum
from boxtree.tools import AXIS_NAMES
from boxtree.traversal import TRAVERSAL_PREAMBLE_TEMPLATE

logger.debug("start building area query kernel")

template = Template(
TRAVERSAL_PREAMBLE_TEMPLATE
+ AREA_QUERY_TEMPLATE,
strict_undefined=True)

render_vars = dict(
np=np,
dimensions=dimensions,
dtype_to_ctype=dtype_to_ctype,
box_id_dtype=box_id_dtype,
particle_id_dtype=None,
coord_dtype=coord_dtype,
get_coord_vec_dtype=get_coord_vec_dtype,
cvec_sub=partial(coord_vec_subscript_code, dimensions),
max_levels=max_levels,
AXIS_NAMES=AXIS_NAMES,
box_flags_enum=box_flags_enum,
peer_list_idx_dtype=peer_list_idx_dtype,
ball_id_dtype=ball_id_dtype,
debug=False,
root_extent_stretch_factor=root_extent_stretch_factor)

from boxtree.tools import VectorArg, ScalarArg
arg_decls = [
VectorArg(coord_dtype, "box_centers", with_offset=False),
ScalarArg(coord_dtype, "root_extent"),
VectorArg(np.uint8, "box_levels"),
ScalarArg(box_id_dtype, "aligned_nboxes"),
VectorArg(box_id_dtype, "box_child_ids", with_offset=False),
VectorArg(box_flags_enum.dtype, "box_flags"),
VectorArg(peer_list_idx_dtype, "peer_list_starts"),
VectorArg(box_id_dtype, "peer_lists"),
VectorArg(coord_dtype, "ball_radii"),
] + [
ScalarArg(coord_dtype, "bbox_min_"+ax)
for ax in AXIS_NAMES[:dimensions]
] + [
VectorArg(coord_dtype, "ball_"+ax)
for ax in AXIS_NAMES[:dimensions]]

from pyopencl.algorithm import ListOfListsBuilder
area_query_knl = ListOfListsBuilder(
actx.context,
[("leaves", box_id_dtype)],
str(template.render(**render_vars)),
arg_decls=arg_decls,
name_prefix="area_query",
count_sharing={},
complex_kernel=True)

logger.debug("done building area query kernel")
return area_query_knl


def build_area_query(
actx: PyOpenCLArrayContext, tree: Tree,
ball_centers, ball_radii, peer_lists=None) -> AreaQueryResult:
Expand Down Expand Up @@ -633,73 +705,11 @@ def build_area_query(

# {{{ area query

@memoize_in(actx, (
build_area_query,
area_query_knl = get_area_query_kernel(
actx,
tree.dimensions, tree.coord_dtype, tree.box_id_dtype,
ball_id_dtype, peer_list_idx_dtype, max_levels))
def get_area_query_kernel():
from pyopencl.tools import dtype_to_ctype

from boxtree import box_flags_enum
from boxtree.tools import AXIS_NAMES
from boxtree.traversal import TRAVERSAL_PREAMBLE_TEMPLATE

logger.debug("start building area query kernel")

template = Template(
TRAVERSAL_PREAMBLE_TEMPLATE
+ AREA_QUERY_TEMPLATE,
strict_undefined=True)

render_vars = dict(
np=np,
dimensions=tree.dimensions,
dtype_to_ctype=dtype_to_ctype,
box_id_dtype=tree.box_id_dtype,
particle_id_dtype=None,
coord_dtype=tree.coord_dtype,
get_coord_vec_dtype=get_coord_vec_dtype,
cvec_sub=partial(coord_vec_subscript_code, tree.dimensions),
max_levels=max_levels,
AXIS_NAMES=AXIS_NAMES,
box_flags_enum=box_flags_enum,
peer_list_idx_dtype=peer_list_idx_dtype,
ball_id_dtype=ball_id_dtype,
debug=False,
root_extent_stretch_factor=tree.root_extent_stretch_factor)

from boxtree.tools import VectorArg, ScalarArg
arg_decls = [
VectorArg(tree.coord_dtype, "box_centers", with_offset=False),
ScalarArg(tree.coord_dtype, "root_extent"),
VectorArg(np.uint8, "box_levels"),
ScalarArg(tree.box_id_dtype, "aligned_nboxes"),
VectorArg(tree.box_id_dtype, "box_child_ids", with_offset=False),
VectorArg(box_flags_enum.dtype, "box_flags"),
VectorArg(peer_list_idx_dtype, "peer_list_starts"),
VectorArg(tree.box_id_dtype, "peer_lists"),
VectorArg(tree.coord_dtype, "ball_radii"),
] + [
ScalarArg(tree.coord_dtype, "bbox_min_"+ax)
for ax in AXIS_NAMES[:tree.dimensions]
] + [
VectorArg(tree.coord_dtype, "ball_"+ax)
for ax in AXIS_NAMES[:tree.dimensions]]

from pyopencl.algorithm import ListOfListsBuilder
area_query_knl = ListOfListsBuilder(
actx.context,
[("leaves", tree.box_id_dtype)],
str(template.render(**render_vars)),
arg_decls=arg_decls,
name_prefix="area_query",
count_sharing={},
complex_kernel=True)

logger.debug("done building area query kernel")
return area_query_knl

area_query_knl = get_area_query_kernel()
ball_id_dtype, peer_list_idx_dtype, max_levels,
tree.root_extent_stretch_factor)

with ProcessLogger(logger, "area query"):
result, _ = area_query_knl(
Expand Down Expand Up @@ -796,6 +806,8 @@ def build_leaves_to_balls_lookup(

# {{{ build lookup

from pytools import memoize_in

@memoize_in(actx, (build_leaves_to_balls_lookup, tree.box_id_dtype))
def get_starts_expander_kernel():
return STARTS_EXPANDER_TEMPLATE.build(
Expand Down Expand Up @@ -878,6 +890,25 @@ def __call__(self,
actx, tree, ball_centers, ball_radii, peer_lists)


@memoize_on_first_arg
def get_space_invader_query_kernel(
actx: PyOpenCLArrayContext,
dimensions: int,
coord_dtype: "np.dtype",
box_id_dtype: "np.dtype",
peer_list_starts_dtype: "np.dtype",
max_levels: int,
root_extent_stretch_factor: float):
return SPACE_INVADER_QUERY_TEMPLATE.generate(
actx.context,
dimensions,
coord_dtype,
box_id_dtype,
peer_list_starts_dtype,
max_levels,
root_extent_stretch_factor=root_extent_stretch_factor)


def build_space_invader_query(
actx: PyOpenCLArrayContext, tree: Tree,
ball_centers, ball_radii, peer_lists=None) -> Array:
Expand Down Expand Up @@ -938,21 +969,12 @@ def build_space_invader_query(

# {{{ build query

@memoize_in(actx, (
build_space_invader_query,
space_invader_query_knl = get_space_invader_query_kernel(
actx,
tree.dimensions, tree.coord_dtype, tree.box_id_dtype,
peer_lists.peer_list_starts.dtype, max_levels))
def get_space_invader_query_kernel():
return SPACE_INVADER_QUERY_TEMPLATE.generate(
actx.context,
tree.dimensions,
tree.coord_dtype,
tree.box_id_dtype,
peer_lists.peer_list_starts.dtype,
max_levels,
root_extent_stretch_factor=tree.root_extent_stretch_factor)

space_invader_query_knl = get_space_invader_query_kernel()
peer_lists.peer_list_starts.dtype,
max_levels, tree.root_extent_stretch_factor,
)

with ProcessLogger(logger, "space invader query"):
outer_space_invader_dists = actx.zeros(tree.nboxes, np.float32)
Expand Down Expand Up @@ -1020,6 +1042,69 @@ class PeerListLookup:
peer_lists: Array


@memoize_on_first_arg
def get_peer_list_finder_kernel(
actx: PyOpenCLArrayContext,
dimensions: int,
coord_dtype: "np.dtype",
box_id_dtype: "np.dtype",
max_levels: int):
from pyopencl.tools import dtype_to_ctype

from boxtree import box_flags_enum
from boxtree.tools import AXIS_NAMES
from boxtree.traversal import (
TRAVERSAL_PREAMBLE_TEMPLATE, HELPER_FUNCTION_TEMPLATE)

logger.debug("start building peer list finder kernel")

template = Template(
TRAVERSAL_PREAMBLE_TEMPLATE
+ HELPER_FUNCTION_TEMPLATE
+ PEER_LIST_FINDER_TEMPLATE,
strict_undefined=True)

render_vars = dict(
np=np,
dimensions=dimensions,
dtype_to_ctype=dtype_to_ctype,
box_id_dtype=box_id_dtype,
particle_id_dtype=None,
coord_dtype=coord_dtype,
get_coord_vec_dtype=get_coord_vec_dtype,
cvec_sub=partial(coord_vec_subscript_code, dimensions),
max_levels=max_levels,
AXIS_NAMES=AXIS_NAMES,
box_flags_enum=box_flags_enum,
debug=False,
# For calls to the helper is_adjacent_or_overlapping()
targets_have_extent=False,
sources_have_extent=False)

from boxtree.tools import VectorArg, ScalarArg
arg_decls = [
VectorArg(coord_dtype, "box_centers", with_offset=False),
ScalarArg(coord_dtype, "root_extent"),
VectorArg(np.uint8, "box_levels"),
ScalarArg(box_id_dtype, "aligned_nboxes"),
VectorArg(box_id_dtype, "box_child_ids", with_offset=False),
VectorArg(box_flags_enum.dtype, "box_flags"),
]

from pyopencl.algorithm import ListOfListsBuilder
peer_list_finder_knl = ListOfListsBuilder(
actx.context,
[("peers", box_id_dtype)],
str(template.render(**render_vars)),
arg_decls=arg_decls,
name_prefix="find_peer_lists",
count_sharing={},
complex_kernel=True)

logger.debug("done building peer list finder kernel")
return peer_list_finder_knl


def build_peer_list(actx: PyOpenCLArrayContext, tree: Tree) -> PeerListLookup:
"""Builds a look-up table from box numbers to peer boxes. The full definition
[1]_ of a peer box is as follows:
Expand All @@ -1044,66 +1129,11 @@ def build_peer_list(actx: PyOpenCLArrayContext, tree: Tree) -> PeerListLookup:
# a stack bound. Rounding avoids too many kernel versions.
max_levels = div_ceil(tree.nlevels, 10) * 10

@memoize_in(actx, (
build_peer_list, tree.dimensions, tree.coord_dtype, tree.box_id_dtype,
max_levels))
def get_peer_list_finder_kernel():
from pyopencl.tools import dtype_to_ctype

from boxtree import box_flags_enum
from boxtree.tools import AXIS_NAMES
from boxtree.traversal import (
TRAVERSAL_PREAMBLE_TEMPLATE, HELPER_FUNCTION_TEMPLATE)

logger.debug("start building peer list finder kernel")

template = Template(
TRAVERSAL_PREAMBLE_TEMPLATE
+ HELPER_FUNCTION_TEMPLATE
+ PEER_LIST_FINDER_TEMPLATE,
strict_undefined=True)

render_vars = dict(
np=np,
dimensions=tree.dimensions,
dtype_to_ctype=dtype_to_ctype,
box_id_dtype=tree.box_id_dtype,
particle_id_dtype=None,
coord_dtype=tree.coord_dtype,
get_coord_vec_dtype=get_coord_vec_dtype,
cvec_sub=partial(coord_vec_subscript_code, tree.dimensions),
max_levels=max_levels,
AXIS_NAMES=AXIS_NAMES,
box_flags_enum=box_flags_enum,
debug=False,
# For calls to the helper is_adjacent_or_overlapping()
targets_have_extent=False,
sources_have_extent=False)

from boxtree.tools import VectorArg, ScalarArg
arg_decls = [
VectorArg(tree.coord_dtype, "box_centers", with_offset=False),
ScalarArg(tree.coord_dtype, "root_extent"),
VectorArg(np.uint8, "box_levels"),
ScalarArg(tree.box_id_dtype, "aligned_nboxes"),
VectorArg(tree.box_id_dtype, "box_child_ids", with_offset=False),
VectorArg(box_flags_enum.dtype, "box_flags"),
]

from pyopencl.algorithm import ListOfListsBuilder
peer_list_finder_knl = ListOfListsBuilder(
actx.context,
[("peers", tree.box_id_dtype)],
str(template.render(**render_vars)),
arg_decls=arg_decls,
name_prefix="find_peer_lists",
count_sharing={},
complex_kernel=True)

logger.debug("done building peer list finder kernel")
return peer_list_finder_knl

peer_list_finder_knl = get_peer_list_finder_kernel()
peer_list_finder_knl = get_peer_list_finder_kernel(
actx,
tree.dimensions, tree.coord_dtype, tree.box_id_dtype,
max_levels,
)

with ProcessLogger(logger, "find peer lists"):
result, evt = peer_list_finder_knl(
Expand Down
Loading

0 comments on commit 189b9fc

Please sign in to comment.