From cfbdf4f1cef2ae0f782edbd557849fe6b5dc916c Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Thu, 13 Oct 2022 11:55:13 -0500 Subject: [PATCH 1/2] fix memoization key. The frozen type of an actx is sensitive to actx's type --- meshmode/discretization/connection/direct.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/meshmode/discretization/connection/direct.py b/meshmode/discretization/connection/direct.py index 9eb0fae8a..0cfcd84dd 100644 --- a/meshmode/discretization/connection/direct.py +++ b/meshmode/discretization/connection/direct.py @@ -356,7 +356,7 @@ def __init__(self, # {{{ _resample_matrix @keyed_memoize_method(key=lambda actx, to_group_index, ibatch_index: - (to_group_index, ibatch_index)) + (type(actx), to_group_index, ibatch_index)) def _resample_matrix(self, actx: ArrayContext, to_group_index: int, ibatch_index: int): import modepy as mp @@ -435,7 +435,8 @@ def _resample_point_pick_indices(self, to_group_index: int, ibatch_index: int, return result @keyed_memoize_method(lambda actx, to_group_index, ibatch_index, - tol_multiplier=None: (to_group_index, ibatch_index, tol_multiplier)) + tol_multiplier=None: (type(actx), to_group_index, + ibatch_index, tol_multiplier)) def _frozen_resample_point_pick_indices(self, actx: ArrayContext, to_group_index: int, ibatch_index: int, tol_multiplier: Optional[float] = None): From 0591cbb788de716fa6de0ae34340eadd97330c63 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Thu, 13 Oct 2022 11:56:16 -0500 Subject: [PATCH 2/2] freeze constant arrays --- meshmode/discretization/connection/direct.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/meshmode/discretization/connection/direct.py b/meshmode/discretization/connection/direct.py index 0cfcd84dd..7052a534a 100644 --- a/meshmode/discretization/connection/direct.py +++ b/meshmode/discretization/connection/direct.py @@ -232,6 +232,12 @@ class _FromGroupPickData(Generic[ArrayT]): from_element_indices: ArrayT is_surjective: bool + @keyed_memoize_method(key=lambda actx: type(actx)) + def indexed_dof_pick_lists(self, actx): + assert actx.permits_advanced_indexing + return actx.freeze( + actx.thaw(self.dof_pick_lists)[actx.thaw(self.dof_pick_list_indices)]) + # }}} @@ -737,8 +743,7 @@ def group_pick_knl(is_surjective: bool): grp_ary_contrib = ary[fgpd.from_group_index][ _reshape_and_preserve_tags( actx, from_element_indices, (-1, 1)), - actx.thaw(fgpd.dof_pick_lists)[ - actx.thaw(fgpd.dof_pick_list_indices)] + actx.thaw(fgpd.indexed_dof_pick_lists(actx)) ] if not fgpd.is_surjective: