Skip to content

Commit

Permalink
Fixes for EmbeddedTransfer
Browse files Browse the repository at this point in the history
  • Loading branch information
pbrubeck committed Nov 15, 2024
1 parent 0e4c04f commit 38e7f78
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 24 deletions.
32 changes: 14 additions & 18 deletions firedrake/mg/embedded.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ class Cache(object):
"""A caching object for work vectors and matrices.
:arg element: The element to use for the caching."""
def __init__(self, key):
self.embedding_element = get_embedding_dg_element(*key)
def __init__(self, ufl_element, value_shape):
self.embedding_element = get_embedding_dg_element(ufl_element, value_shape)
self._dat_versions = {}
self._V_DG_mass = {}
self._DG_inv_mass = {}
Expand Down Expand Up @@ -83,24 +83,20 @@ def _native_transfer(self, element, op):
return self.native_transfers.setdefault(element, ops)[op]
return None

def cache(self, key):
def cache(self, V):
key = (V.ufl_element(), V.value_shape)
try:
return self.caches[key]
except KeyError:
return self.caches.setdefault(key, TransferManager.Cache(key))

def get_cache_key(self, V):
elem = V.ufl_element()
value_shape = V.value_shape
return elem, value_shape
return self.caches.setdefault(key, TransferManager.Cache(*key))

def V_dof_weights(self, V):
"""Dof weights for averaging projection.
:arg V: function space to compute weights for.
:returns: A PETSc Vec.
"""
cache = self.cache(self.get_cache_key(V))
cache = self.cache(V)
key = V.dim()
try:
return cache._V_dof_weights[key]
Expand All @@ -125,7 +121,7 @@ def V_DG_mass(self, V, DG):
:arg DG: the DG space
:returns: A PETSc Mat mapping from V -> DG
"""
cache = self.cache(self.get_cache_key(V))
cache = self.cache(V)
key = V.dim()
try:
return cache._V_DG_mass[key]
Expand All @@ -140,7 +136,7 @@ def DG_inv_mass(self, DG):
:arg DG: the DG space
:returns: A PETSc Mat.
"""
cache = self.cache(self.get_cache_key(DG))
cache = self.cache(DG)
key = DG.dim()
try:
return cache._DG_inv_mass[key]
Expand All @@ -156,7 +152,7 @@ def V_approx_inv_mass(self, V, DG):
:arg DG: the DG space
:returns: A PETSc Mat mapping from V -> DG.
"""
cache = self.cache(self.get_cache_key(V))
cache = self.cache(DG)
key = V.dim()
try:
return cache._V_approx_inv_mass[key]
Expand All @@ -174,7 +170,7 @@ def V_inv_mass_ksp(self, V):
:arg V: a function space.
:returns: A PETSc KSP for inverting (V, V).
"""
cache = self.cache(self.get_cache_key(V))
cache = self.cache(V)
key = V.dim()
try:
return cache._V_inv_mass_ksp[key]
Expand All @@ -196,7 +192,7 @@ def DG_work(self, V):
:returns: A Function in the embedding DG space.
"""
needs_dual = ufl.duals.is_dual(V)
cache = self.cache(self.get_cache_key(V))
cache = self.cache(V)
key = (V.dim(), needs_dual)
try:
return cache._DG_work[key]
Expand All @@ -213,7 +209,7 @@ def work_vec(self, V):
:arg V: a function space.
:returns: A PETSc Vec for V.
"""
cache = self.cache(self.get_cache_key(V))
cache = self.cache(V)
key = V.dim()
try:
return cache._work_vec[key]
Expand All @@ -226,15 +222,15 @@ def requires_transfer(self, V, transfer_op, source, target):
key = (transfer_op, weakref.ref(source.dat), weakref.ref(target.dat))
dat_versions = (source.dat.dat_version, target.dat.dat_version)
try:
return self.cache(self.get_cache_key(V))._dat_versions[key] != dat_versions
return self.cache(V)._dat_versions[key] != dat_versions
except KeyError:
return True

def cache_dat_versions(self, V, transfer_op, source, target):
"""Record the returned dat_versions of the source and target."""
key = (transfer_op, weakref.ref(source.dat), weakref.ref(target.dat))
dat_versions = (source.dat.dat_version, target.dat.dat_version)
self.cache(self.get_cache_key(V))._dat_versions[key] = dat_versions
self.cache(V)._dat_versions[key] = dat_versions

@PETSc.Log.EventDecorator()
def op(self, source, target, transfer_op):
Expand Down
8 changes: 4 additions & 4 deletions firedrake/preconditioners/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,18 +585,18 @@ def bcdofs(bc, ghost=True):
if ghost:
offset += sum(Z.sub(j).dof_count for j in range(idx))
else:
offset += sum(Z.sub(j).dof_dset.size * Z.sub(j).value_size for j in range(idx))
offset += sum(Z.sub(j).dof_dset.size * Z.sub(j).block_size for j in range(idx))
else:
raise NotImplementedError("How are you taking a .sub?")

Z = Z.sub(idx)

if Z.parent is not None and isinstance(Z.parent.ufl_element(), VectorElement):
bs = Z.parent.value_size
bs = Z.parent.block_size
start = 0
stop = 1
else:
bs = Z.value_size
bs = Z.block_size
start = 0
stop = bs
nodes = bc.nodes
Expand Down Expand Up @@ -868,7 +868,7 @@ def initialize(self, obj):
offsets = numpy.append([0], numpy.cumsum([W.dof_count
for W in V])).astype(PETSc.IntType)
patch.setPatchDiscretisationInfo([W.dm for W in V],
numpy.array([W.value_size for
numpy.array([W.block_size for
W in V], dtype=PETSc.IntType),
[W.cell_node_list for W in V],
offsets,
Expand Down
4 changes: 2 additions & 2 deletions firedrake/preconditioners/pmg.py
Original file line number Diff line number Diff line change
Expand Up @@ -1375,8 +1375,8 @@ def make_blas_kernels(self, Vf, Vc):
# We could benefit from loop tiling for the transpose, but that makes the code
# more complicated.

fshape = (numpy.prod(Vf.shape), Vf.finat_element.space_dimension())
cshape = (numpy.prod(Vc.shape), Vc.finat_element.space_dimension())
fshape = (Vf.block_size, Vf.finat_element.space_dimension())
cshape = (Vc.block_size, Vc.finat_element.space_dimension())

lwork = numpy.prod([max(*dims) for dims in zip(*shapes)])
lwork = max(lwork, max(numpy.prod(fshape), numpy.prod(cshape)))
Expand Down

0 comments on commit 38e7f78

Please sign in to comment.