Skip to content

Commit

Permalink
Fix assembly of Real matrices (#3846)
Browse files Browse the repository at this point in the history
* Fix assembly of Real matrices
  • Loading branch information
connorjward authored Nov 13, 2024
1 parent d12e3ac commit a59b15f
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 76 deletions.
157 changes: 81 additions & 76 deletions firedrake/assemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from firedrake.utils import ScalarType, assert_empty, tuplify
from pyop2 import op2
from pyop2.exceptions import MapValueError, SparsityFormatError
from pyop2.types.mat import _GlobalMatPayload, _DatMatPayload
from pyop2.utils import cached_property


Expand Down Expand Up @@ -965,22 +966,24 @@ def assemble(self, tensor=None):
Result of assembly: `float` for 0-forms, `firedrake.cofunction.Cofunction` or `firedrake.function.Function` for 1-forms, and `matrix.MatrixBase` for 2-forms.
"""
self._check_tensor(tensor)
if tensor is None:
tensor = self.allocate()
needs_zeroing = False
else:
needs_zeroing = self._needs_zeroing
if annotate_tape():
raise NotImplementedError(
"Taping with explicit FormAssembler objects is not supported yet. "
"Use assemble instead."
)
if needs_zeroing:
type(self)._as_pyop2_type(tensor).zero()

if tensor is None:
tensor = self.allocate()
else:
self._check_tensor(tensor)
if self._needs_zeroing:
self._as_pyop2_type(tensor).zero()

self.execute_parloops(tensor)

for bc in self._bcs:
self._apply_bc(tensor, bc)

return self.result(tensor)

@abc.abstractmethod
Expand All @@ -992,9 +995,9 @@ def _check_tensor(self, tensor):
"""Check input tensor."""

@staticmethod
def _as_pyop2_type(tensor):
"""Return tensor as pyop2 type."""
raise NotImplementedError
@abc.abstractmethod
def _as_pyop2_type(tensor, indices=None):
"""Cast a Firedrake tensor into a PyOP2 data structure, optionally indexing it."""

def execute_parloops(self, tensor):
for parloop in self.parloops(tensor):
Expand All @@ -1003,29 +1006,27 @@ def execute_parloops(self, tensor):
def parloops(self, tensor):
if hasattr(self, "_parloops"):
for (lknl, _), parloop in zip(self.local_kernels, self._parloops):
data = _FormHandler.index_tensor(tensor, self._form, lknl.indices, self.diagonal)
data = self._as_pyop2_type(tensor, lknl.indices)
parloop.arguments[0].data = data

else:
# Make parloops for one concrete output tensor and cache them.
# TODO: Make parloops only with some symbolic information of the output tensor.
self._parloops = tuple(parloop_builder.build(tensor) for parloop_builder in self.parloop_builders)
return self._parloops

@cached_property
def parloop_builders(self):
out = []
for local_kernel, subdomain_id in self.local_kernels:
out.append(
ParloopBuilder(
parloops_ = []
for local_kernel, subdomain_id in self.local_kernels:
parloop_builder = ParloopBuilder(
self._form,
self._bcs,
local_kernel,
subdomain_id,
self.all_integer_subdomain_ids,
diagonal=self.diagonal,
)
)
return tuple(out)
pyop2_tensor = self._as_pyop2_type(tensor, local_kernel.indices)
parloop = parloop_builder.build(pyop2_tensor)
parloops_.append(parloop)
self._parloops = tuple(parloops_)

return self._parloops

@cached_property
def local_kernels(self):
Expand Down Expand Up @@ -1120,10 +1121,11 @@ def _apply_bc(self, tensor, bc):
pass

def _check_tensor(self, tensor):
assert tensor is None
pass

@staticmethod
def _as_pyop2_type(tensor):
def _as_pyop2_type(tensor, indices=None):
assert not indices
return tensor

def result(self, tensor):
Expand Down Expand Up @@ -1198,15 +1200,16 @@ def _apply_dirichlet_bc(self, tensor, bc):
bc.zero(tensor)

def _check_tensor(self, tensor):
rank = len(self._form.arguments())
if rank == 1:
test, = self._form.arguments()
if tensor is not None and test.function_space() != tensor.function_space():
raise ValueError("Form's argument does not match provided result tensor")
if tensor.function_space() != self._form.arguments()[0].function_space():
raise ValueError("Form's argument does not match provided result tensor")

@staticmethod
def _as_pyop2_type(tensor):
return tensor.dat
def _as_pyop2_type(tensor, indices=None):
if indices is not None and any(index is not None for index in indices):
i, = indices
return tensor.dat[i]
else:
return tensor.dat

def execute_parloops(self, tensor):
# We are repeatedly incrementing into the same Dat so intermediate halo exchanges
Expand Down Expand Up @@ -1454,12 +1457,26 @@ def _apply_bcs_mat_real_block(op2tensor, i, j, component, node_set):
dat.zero(subset=node_set)

def _check_tensor(self, tensor):
if tensor is not None and tensor.a.arguments() != self._form.arguments():
if tensor.a.arguments() != self._form.arguments():
raise ValueError("Form's arguments do not match provided result tensor")

@staticmethod
def _as_pyop2_type(tensor):
return tensor.M
def _as_pyop2_type(tensor, indices=None):
if indices is not None and any(index is not None for index in indices):
i, j = indices
mat = tensor.M[i, j]
else:
mat = tensor.M

if mat.handle.getType() == "python":
mat_context = mat.handle.getPythonContext()
if isinstance(mat_context, _GlobalMatPayload):
mat = mat_context.global_
else:
assert isinstance(mat_context, _DatMatPayload)
mat = mat_context.dat

return mat

def result(self, tensor):
tensor.M.assemble()
Expand All @@ -1471,7 +1488,7 @@ class MatrixFreeAssembler(FormAssembler):
Parameters
----------
form : ufl.Form or slate.TensorBasehe
form : ufl.Form or slate.TensorBase
2-form.
Notes
Expand All @@ -1498,14 +1515,15 @@ def allocate(self):
appctx=self._appctx or {})

def assemble(self, tensor=None):
self._check_tensor(tensor)
if tensor is None:
tensor = self.allocate()
else:
self._check_tensor(tensor)
tensor.assemble()
return tensor

def _check_tensor(self, tensor):
if tensor is not None and tensor.a.arguments() != self._form.arguments():
if tensor.a.arguments() != self._form.arguments():
raise ValueError("Form's arguments do not match provided result tensor")


Expand Down Expand Up @@ -1820,12 +1838,12 @@ def __init__(self, form, bcs, local_knl, subdomain_id,
self._active_coefficients = _FormHandler.iter_active_coefficients(form, local_knl.kinfo)
self._constants = _FormHandler.iter_constants(form, local_knl.kinfo)

def build(self, tensor):
def build(self, tensor: op2.Global | op2.Dat | op2.Mat) -> op2.Parloop:
"""Construct the parloop.
Parameters
----------
tensor : op2.Global or firedrake.cofunction.Cofunction or matrix.MatrixBase
tensor :
The output tensor.
"""
Expand Down Expand Up @@ -1909,17 +1927,28 @@ def collect_lgmaps(self):
:param local_knl: A :class:`tsfc_interface.SplitKernel`.
:param bcs: Iterable of boundary conditions.
"""

if len(self._form.arguments()) == 2 and not self._diagonal:
if not self._bcs:
return None
lgmaps = []
for i, j in self.get_indicess():

if any(i is not None for i in self._local_knl.indices):
i, j = self._local_knl.indices
row_bcs, col_bcs = self._filter_bcs(i, j)
rlgmap, clgmap = self._tensor.M[i, j].local_to_global_maps
# the tensor is already indexed
rlgmap, clgmap = self._tensor.local_to_global_maps
rlgmap = self.test_function_space[i].local_to_global_map(row_bcs, rlgmap)
clgmap = self.trial_function_space[j].local_to_global_map(col_bcs, clgmap)
lgmaps.append((rlgmap, clgmap))
return tuple(lgmaps)
return ((rlgmap, clgmap),)
else:
lgmaps = []
for i, j in self.get_indicess():
row_bcs, col_bcs = self._filter_bcs(i, j)
rlgmap, clgmap = self._tensor[i, j].local_to_global_maps
rlgmap = self.test_function_space[i].local_to_global_map(row_bcs, rlgmap)
clgmap = self.trial_function_space[j].local_to_global_map(col_bcs, clgmap)
lgmaps.append((rlgmap, clgmap))
return tuple(lgmaps)
else:
return None

Expand All @@ -1939,10 +1968,6 @@ def _integral_type(self):
def _indexed_function_spaces(self):
return _FormHandler.index_function_spaces(self._form, self._indices)

# View of ``self._tensor`` obtained by delegating to
# ``_FormHandler.index_tensor`` with this object's form, indices and
# diagonal flag.
@property
def _indexed_tensor(self):
return _FormHandler.index_tensor(self._tensor, self._form, self._indices, self._diagonal)

@cached_property
def _mesh(self):
return self._form.ufl_domains()[self._kinfo.domain_number]
Expand Down Expand Up @@ -1990,28 +2015,27 @@ def _as_parloop_arg(tsfc_arg, self):
@_as_parloop_arg.register(kernel_args.OutputKernelArg)
def _as_parloop_arg_output(_, self):
rank = len(self._form.arguments())
tensor = self._indexed_tensor
Vs = self._indexed_function_spaces

if rank == 0:
return op2.GlobalParloopArg(tensor)
return op2.GlobalParloopArg(self._tensor)
elif rank == 1 or rank == 2 and self._diagonal:
V, = Vs
if V.ufl_element().family() == "Real":
return op2.GlobalParloopArg(tensor)
return op2.GlobalParloopArg(self._tensor)
else:
return op2.DatParloopArg(tensor, self._get_map(V))
return op2.DatParloopArg(self._tensor, self._get_map(V))
elif rank == 2:
rmap, cmap = [self._get_map(V) for V in Vs]

if all(V.ufl_element().family() == "Real" for V in Vs):
assert rmap is None and cmap is None
return op2.GlobalParloopArg(tensor.handle.getPythonContext().global_)
return op2.GlobalParloopArg(self._tensor)
elif any(V.ufl_element().family() == "Real" for V in Vs):
m = rmap or cmap
return op2.DatParloopArg(tensor.handle.getPythonContext().dat, m)
return op2.DatParloopArg(self._tensor, m)
else:
return op2.MatParloopArg(tensor, (rmap, cmap), lgmaps=self.collect_lgmaps())
return op2.MatParloopArg(self._tensor, (rmap, cmap), lgmaps=self.collect_lgmaps())
else:
raise AssertionError

Expand Down Expand Up @@ -2122,22 +2146,3 @@ def index_function_spaces(form, indices):
return tuple(a.ufl_function_space()[i] for i, a in zip(indices, form.arguments()))
else:
raise AssertionError

@staticmethod
def index_tensor(tensor, form, indices, diagonal):
"""Return the PyOP2 data structure tied to ``tensor``, indexed
if necessary.
"""
# The number of form arguments decides which attribute of ``tensor`` is used.
rank = len(form.arguments())
# Indexing only happens when at least one entry of ``indices`` is not None.
is_indexed = any(i is not None for i in indices)

if rank == 0:
# 0-forms: the tensor is handed back unchanged.
return tensor
elif rank == 1 or rank == 2 and diagonal:
# 1-forms, and 2-forms with ``diagonal`` set: use ``tensor.dat``,
# which takes exactly one index.
i, = indices
return tensor.dat[i] if is_indexed else tensor.dat
elif rank == 2:
# Remaining 2-forms: use ``tensor.M``, which takes exactly two indices.
i, j = indices
return tensor.M[i, j] if is_indexed else tensor.M
else:
# Any other rank is not handled here.
raise AssertionError
18 changes: 18 additions & 0 deletions tests/regression/test_assemble.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest
import numpy as np
from firedrake import *
from firedrake.assemble import TwoFormAssembler
from firedrake.utils import ScalarType, IntType


Expand Down Expand Up @@ -125,6 +126,23 @@ def test_assemble_mat_with_tensor(mesh):
assert np.allclose(M.M.values, 2*assemble(a).M.values, rtol=1e-14)


# Regression test: assembling a second time into a previously assembled
# "nest" matrix on a mixed space with a Real ("R") component must hand back
# the tensor that was passed in.
@pytest.mark.skipcomplex
def test_mat_nest_real_block_assembler_correctly_reuses_tensor(mesh):
# Mixed space pairing a CG1 space with a Real ("R") space.
V = FunctionSpace(mesh, "CG", 1)
R = FunctionSpace(mesh, "R", 0)
W = V * R

u = TrialFunction(W)
v = TestFunction(W)
a = inner(v, u) * dx

# The first call is given no tensor; the second call passes the first
# result back in as ``tensor``.
assembler = TwoFormAssembler(a, mat_type="nest")
A1 = assembler.assemble()
A2 = assembler.assemble(tensor=A1)

# Identity (not equality) check on the underlying ``M`` attribute.
assert A2.M is A1.M


def test_assemble_diagonal(mesh):
V = FunctionSpace(mesh, "P", 3)
u = TrialFunction(V)
Expand Down

0 comments on commit a59b15f

Please sign in to comment.