Skip to content

Commit

Permalink
fix porting changes
Browse files Browse the repository at this point in the history
  • Loading branch information
alexfikl committed Sep 10, 2022
1 parent 4a46bb6 commit 3a4b48b
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 22 deletions.
13 changes: 4 additions & 9 deletions boxtree/pyfmmlib_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,17 +84,12 @@ def __init__(self, array_context: PyOpenCLArrayContext, trav):
self.trav = trav
self.tree = trav.tree

@property
@memoize_method
def rotation_classes_builder(self):
from boxtree.rotation_classes import RotationClassesBuilder
return RotationClassesBuilder(self._setup_actx)

@memoize_method
def build_rotation_classes_lists(self):
trav = self._setup_actx.from_numpy(self.trav)
tree = self._setup_actx.from_numpy(self.tree)
return self.rotation_classes_builder(self._setup_actx, trav, tree)[0]
from boxtree.rotation_classes import build_rotation_classes
actx = self._setup_actx
return build_rotation_classes(
actx, actx.from_numpy(self.trav), actx.from_numpy(self.tree))

@memoize_method
def m2l_rotation_lists(self):
Expand Down
6 changes: 3 additions & 3 deletions boxtree/rotation_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,9 +161,9 @@ def build_rotation_classes(
actx: PyOpenCLArrayContext,
trav: FMMTraversalInfo, tree: Tree) -> RotationClassesInfo:
"""Build rotation classes for List 2 translations."""
from boxtree.translation_classes import compute_used_tranlation_classes
from boxtree.translation_classes import compute_used_translation_classes
translation_class_is_used, translation_classes_lists = (
compute_used_tranlation_classes(actx, trav, tree,
compute_used_translation_classes(actx, trav, tree,
is_translation_per_level=False))

d = tree.dimensions
Expand All @@ -176,7 +176,7 @@ def build_rotation_classes(

translation_class_to_rotation_class, rotation_angles = (
translation_classes_to_rotation_classes_and_angles(
n, d, used_translation_classes))
used_translation_classes, n, d))

# There should be no more than 2^(d-1) * (2n+1)^d distinct rotation
# classes, since that is an upper bound on the number of distinct
Expand Down
6 changes: 3 additions & 3 deletions boxtree/translation_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ def translation_class_to_normalized_vector(
return result


def compute_used_tranlation_classes(
def compute_used_translation_classes(
actx: PyOpenCLArrayContext, trav: FMMTraversalInfo, tree: Tree, *,
is_translation_per_level: bool):
# {{{ compute translation classes for list 2
Expand All @@ -282,7 +282,7 @@ def compute_used_tranlation_classes(
ntranslation_classes = ntranslation_classes * tree.nlevels

@memoize_in(actx, (
compute_used_tranlation_classes,
compute_used_translation_classes,
dimensions, well_sep_is_n_away, tree.box_id_dtype,
tree.box_level_dtype, coord_dtype, is_translation_per_level))
def get_translation_class_finder_knl():
Expand Down Expand Up @@ -356,7 +356,7 @@ def build_translation_classes(actx: PyOpenCLArrayContext,
is_translation_per_level: bool = True) -> TranslationClassesInfo:
"""Build translation classes for List 2 translations."""
translation_class_is_used, translation_classes_lists = (
compute_used_tranlation_classes(actx, trav, tree,
compute_used_translation_classes(actx, trav, tree,
is_translation_per_level=is_translation_per_level))

well_sep_is_n_away = trav.well_sep_is_n_away
Expand Down
14 changes: 7 additions & 7 deletions test/test_traversal.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,17 +342,14 @@ def test_from_sep_siblings_translation_and_rotation_classes(
# {{{ build traversal

from boxtree.traversal import FMMTraversalBuilder
from boxtree.rotation_classes import RotationClassesBuilder
from boxtree.translation_classes import TranslationClassesBuilder
from boxtree.rotation_classes import build_rotation_classes
from boxtree.translation_classes import build_translation_classes

tg = FMMTraversalBuilder(actx, well_sep_is_n_away=well_sep_is_n_away)
trav, _ = tg(actx, tree)

rb = RotationClassesBuilder(actx)
result, _ = rb(actx, trav, tree)

tb = TranslationClassesBuilder(actx)
result_tb, _ = tb(actx, trav, tree)
result = build_rotation_classes(actx, trav, tree)
result_tb = build_translation_classes(actx, trav, tree)

rot_classes = actx.to_numpy(
result.from_sep_siblings_rotation_classes)
Expand All @@ -364,6 +361,9 @@ def test_from_sep_siblings_translation_and_rotation_classes(
distance_vectors = actx.to_numpy(
result_tb.from_sep_siblings_translation_class_to_distance_vector)

print(rot_classes)
breakpoint()

tree = actx.to_numpy(tree)
trav = actx.to_numpy(trav)

Expand Down

0 comments on commit 3a4b48b

Please sign in to comment.