Skip to content

Commit

Permalink
Add separate function
Browse files Browse the repository at this point in the history
  • Loading branch information
jpdean committed Jan 9, 2025
1 parent e217faf commit 6798def
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 61 deletions.
4 changes: 2 additions & 2 deletions python/demo/demo_mixed-topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
FunctionSpace,
assemble_matrix,
coordinate_element,
form,
mixed_topology_form,
)
from dolfinx.io.utils import cell_perm_vtk
from dolfinx.mesh import CellType, Mesh
Expand Down Expand Up @@ -104,7 +104,7 @@
a += [(ufl.inner(ufl.grad(u), ufl.grad(v)) - k**2 * u * v) * ufl.dx]

# Compile the form
a_form = form(a)
a_form = mixed_topology_form(a)

# Assemble the matrix
A = assemble_matrix(a_form)
Expand Down
2 changes: 2 additions & 0 deletions python/dolfinx/fem/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
extract_function_spaces,
form,
form_cpp_class,
mixed_topology_form,
)
from dolfinx.fem.function import (
Constant,
Expand Down Expand Up @@ -219,6 +220,7 @@ def compute_integration_domains(
"functionspace",
"locate_dofs_geometrical",
"locate_dofs_topological",
"mixed_topology_form",
"set_bc",
"transpose_dofmap",
]
144 changes: 85 additions & 59 deletions python/dolfinx/fem/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,90 @@ def form_cpp_class(
}


def mixed_topology_form(
forms: typing.Union[typing.Iterable[ufl.Form]],
dtype: npt.DTypeLike = default_scalar_type,
form_compiler_options: typing.Optional[dict] = None,
jit_options: typing.Optional[dict] = None,
entity_maps: typing.Optional[dict[Mesh, np.typing.NDArray[np.int32]]] = None,
):
"""
Create a mixed-topology from from an array of Forms.
# FIXME: This function is a temporary hack for mixed-topology meshes. It is needed
# because UFL does not know about mixed-topology meshes, so we need
# to pass a list of forms for each cell type.
Args:
form: A list of UFL forms. Each form should be the same, just
defined on different cell types.
dtype: Scalar type to use for the compiled form.
form_compiler_options: See :func:`ffcx_jit <dolfinx.jit.ffcx_jit>`
jit_options: See :func:`ffcx_jit <dolfinx.jit.ffcx_jit>`.
entity_maps: If any trial functions, test functions, or
coefficients in the form are not defined over the same mesh
as the integration domain, `entity_maps` must be supplied.
For each key (a mesh, different to the integration domain
mesh) a map should be provided relating the entities in the
integration domain mesh to the entities in the key mesh e.g.
for a key-value pair (msh, emap) in `entity_maps`, `emap[i]`
is the entity in `msh` corresponding to entity `i` in the
integration domain mesh.
Returns:
Compiled finite element Form.
"""

if form_compiler_options is None:
form_compiler_options = dict()

form_compiler_options["scalar_type"] = dtype
ftype = form_cpp_class(dtype)

# Extract subdomain data from UFL form
sd = forms[0].subdomain_data()
(domain,) = list(sd.keys()) # Assuming single domain

# Check that subdomain data for each integral type is the same
for data in sd.get(domain).values():
assert all([d is data[0] for d in data if d is not None])

mesh = domain.ufl_cargo()

ufcx_forms = []
modules = []
codes = []
for form in forms:
ufcx_form, module, code = jit.ffcx_jit(
mesh.comm,
form,
form_compiler_options=form_compiler_options,
jit_options=jit_options,
)
ufcx_forms.append(ufcx_form)
modules.append(module)
codes.append(code)

# In a mixed-topology mesh, each form has the same C++ function space,
# so we can extract it from any of them
V = [arg.ufl_function_space()._cpp_object for arg in form.arguments()]

coeffs = []
constants = []
subdomains = {}
entity_maps = {}
f = ftype(
[module.ffi.cast("uintptr_t", module.ffi.addressof(ufcx_form)) for ufcx_form in ufcx_forms],
V,
coeffs,
constants,
subdomains,
entity_maps,
mesh,
)
return Form(f, ufcx_forms, codes, modules)


def form(
form: typing.Union[ufl.Form, typing.Iterable[ufl.Form]],
dtype: npt.DTypeLike = default_scalar_type,
Expand Down Expand Up @@ -312,55 +396,6 @@ def _form(form):
)
return Form(f, ufcx_form, code, module)

# Temporary hack for mixed-topology meshes. This is needed because UFL
# does not know about mixed-topology meshes yet.
def _form_mixed_topo(forms):
# Extract subdomain data from UFL form
sd = forms[0].subdomain_data()
(domain,) = list(sd.keys()) # Assuming single domain

# Check that subdomain data for each integral type is the same
for data in sd.get(domain).values():
assert all([d is data[0] for d in data if d is not None])

mesh = domain.ufl_cargo()

ufcx_forms = []
modules = []
codes = []
for form in forms:
ufcx_form, module, code = jit.ffcx_jit(
mesh.comm,
form,
form_compiler_options=form_compiler_options,
jit_options=jit_options,
)
ufcx_forms.append(ufcx_form)
modules.append(module)
codes.append(code)

# In a mixed-topology mesh, each form has the same C++ function space,
# so we can extract it from any of them
V = [arg.ufl_function_space()._cpp_object for arg in form.arguments()]

coeffs = []
constants = []
subdomains = {}
entity_maps = {}
f = ftype(
[
module.ffi.cast("uintptr_t", module.ffi.addressof(ufcx_form))
for ufcx_form in ufcx_forms
],
V,
coeffs,
constants,
subdomains,
entity_maps,
mesh,
)
return Form(f, ufcx_forms, codes, modules)

def _create_form(form):
"""Recursively convert ufl.Forms to dolfinx.fem.Form.
Expand All @@ -377,16 +412,7 @@ def _create_form(form):
else:
return _form(form)
elif isinstance(form, collections.abc.Iterable):
# FIXME Temporary hack for mixed-topology meshes. This is needed
# because UFL does not know about mixed-topology meshes, so we need
# to pass a list of forms for each cell type.
sd = form[0].subdomain_data()
(domain,) = list(sd.keys())
mesh = domain.ufl_cargo()
if len(mesh.topology.cell_types) > 1:
return _form_mixed_topo(form)
else:
return list(map(lambda sub_form: _create_form(sub_form), form))
return list(map(lambda sub_form: _create_form(sub_form), form))
else:
return form

Expand Down

0 comments on commit 6798def

Please sign in to comment.