Skip to content

Commit

Permalink
Get value_shape from FunctionSpace
Browse files Browse the repository at this point in the history
  • Loading branch information
pbrubeck committed Nov 13, 2024
1 parent 98e4ea3 commit a2537dc
Show file tree
Hide file tree
Showing 27 changed files with 56 additions and 68 deletions.
4 changes: 2 additions & 2 deletions firedrake/bcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,12 +348,12 @@ def function_arg(self, g):
raise RuntimeError("%r is defined on incompatible FunctionSpace!" % g)
self._function_arg = g
elif isinstance(g, ufl.classes.Zero):
if g.ufl_shape and g.ufl_shape != V.ufl_element().value_shape(V.mesh()):
if g.ufl_shape and g.ufl_shape != V.value_shape:
raise ValueError(f"Provided boundary value {g} does not match shape of space")
# Special case. Scalar zero for direct Function.assign.
self._function_arg = g
elif isinstance(g, ufl.classes.Expr):
if g.ufl_shape != V.ufl_element().value_shape(V.mesh()):
if g.ufl_shape != V.value_shape:
raise RuntimeError(f"Provided boundary value {g} does not match shape of space")
try:
self._function_arg = firedrake.Function(V)
Expand Down
4 changes: 2 additions & 2 deletions firedrake/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -935,7 +935,7 @@ def save_function(self, f, idx=None, name=None, timestepping_info={}):
self._update_function_name_function_space_name_map(tmesh.name, mesh.name, {f.name(): V_name})
# Embed if necessary
element = V.ufl_element()
_element = get_embedding_element_for_checkpointing(element, element.value_shape(mesh))
_element = get_embedding_element_for_checkpointing(element, V.value_shape)
if _element != element:
path = self._path_to_function_embedded(tmesh.name, mesh.name, V_name, f.name())
self.require_group(path)
Expand Down Expand Up @@ -1337,7 +1337,7 @@ def load_function(self, mesh, name, idx=None):
_name = self.get_attr(path, PREFIX_EMBEDDED + "_function")
_f = self.load_function(mesh, _name, idx=idx)
element = V.ufl_element()
_element = get_embedding_element_for_checkpointing(element, element.value_shape(mesh))
_element = get_embedding_element_for_checkpointing(element, V.value_shape)
method = get_embedding_method_for_checkpointing(element)
assert _element == _f.function_space().ufl_element()
f = Function(V, name=name)
Expand Down
2 changes: 1 addition & 1 deletion firedrake/external_operators/point_expr_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def __init__(self, *operands, function_space, derivatives=None, argument_slots=(
if not isinstance(operator_data["func"], types.FunctionType):
raise TypeError("Expecting a FunctionType pointwise expression")
expr_shape = operator_data["func"](*operands).ufl_shape
if expr_shape != function_space.ufl_element().value_shape(function_space.mesh()):
if expr_shape != function_space.value_shape:
raise ValueError("The dimension does not match with the dimension of the function space %s" % function_space)

@property
Expand Down
3 changes: 1 addition & 2 deletions firedrake/formmanipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,7 @@ def argument(self, o):
args += [a_[j] for j in numpy.ndindex(a_.ufl_shape)]
else:
args += [Zero()
for j in numpy.ndindex(
V_is[i].ufl_element().value_shape(V_is[i].mesh()))]
for j in numpy.ndindex(V_is[i].value_shape)]
return self._arg_cache.setdefault(o, as_vector(args))


Expand Down
9 changes: 3 additions & 6 deletions firedrake/functionspaceimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,16 +176,13 @@ def split(self):
def _components(self):
if len(self) == 1:
return tuple(type(self).create(self.topological.sub(i), self.mesh())
for i in range(self.value_size))
for i in range(numpy.prod(self.shape)))
else:
return self.subfunctions

@PETSc.Log.EventDecorator()
def sub(self, i):
if len(self) == 1:
bound = self.value_size
else:
bound = len(self)
bound = len(self._components)
if i < 0 or i >= bound:
raise IndexError("Invalid component %d, not in [0, %d)" % (i, bound))
return self._components[i]
Expand Down Expand Up @@ -654,7 +651,7 @@ def __getitem__(self, i):

@utils.cached_property
def _components(self):
return tuple(ComponentFunctionSpace(self, i) for i in range(self.value_size))
return tuple(ComponentFunctionSpace(self, i) for i in range(numpy.prod(self.shape)))

def sub(self, i):
r"""Return a view into the ith component."""
Expand Down
13 changes: 6 additions & 7 deletions firedrake/interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,8 +546,7 @@ def __init__(
# For a VectorElement or TensorElement the correct
# VectorFunctionSpace equivalent is built from the scalar
# sub-element.
ufl_scalar_element = ufl_scalar_element.sub_elements[0]
if ufl_scalar_element.value_shape(V_dest.mesh()) != ():
if V_dest.sub(0).value_shape != ():
raise NotImplementedError(
"Can't yet cross-mesh interpolate onto function spaces made from VectorElements or TensorElements made from sub elements with value shape other than ()."
)
Expand Down Expand Up @@ -614,7 +613,7 @@ def __init__(
# I first point evaluate my expression at these locations, giving a
# P0DG function on the VOM. As described in the manual, this is an
# interpolation operation.
shape = V_dest.ufl_element().value_shape(V_dest.mesh())
shape = V_dest.value_shape
if len(shape) == 0:
fs_type = firedrake.FunctionSpace
elif len(shape) == 1:
Expand Down Expand Up @@ -988,7 +987,7 @@ def callable():
else:
# Make sure we have an expression of the right length i.e. a value for
# each component in the value shape of each function space
dims = [numpy.prod(fs.ufl_element().value_shape(fs.mesh()), dtype=int)
dims = [numpy.prod(fs.value_shape, dtype=int)
for fs in V]
loops = []
if numpy.prod(expr.ufl_shape, dtype=int) != sum(dims):
Expand Down Expand Up @@ -1024,11 +1023,11 @@ def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None):
if access is op2.READ:
raise ValueError("Can't have READ access for output function")

if len(expr.ufl_shape) != len(V.ufl_element().value_shape(V.mesh())):
if len(expr.ufl_shape) != len(V.value_shape):
raise RuntimeError('Rank mismatch: Expression rank %d, FunctionSpace rank %d'
% (len(expr.ufl_shape), len(V.ufl_element().value_shape(V.mesh()))))
% (len(expr.ufl_shape), len(V.value_shape)))

if expr.ufl_shape != V.ufl_element().value_shape(V.mesh()):
if expr.ufl_shape != V.value_shape:
raise RuntimeError('Shape mismatch: Expression shape %r, FunctionSpace shape %r'
% (expr.ufl_shape, V.ufl_element().value_shape(V.mesh())))

Expand Down
2 changes: 1 addition & 1 deletion firedrake/mg/embedded.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def cache(self, key):

def get_cache_key(self, V):
elem = V.ufl_element()
value_shape = elem.value_shape(V.mesh())
value_shape = V.value_shape
return elem, value_shape

def V_dof_weights(self, V):
Expand Down
19 changes: 9 additions & 10 deletions firedrake/mg/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,11 +177,10 @@ def compile_element(expression, dual_space=None, parameters=None,
return_variable = gem.Indexed(gem.Variable('R', finat_elem.index_shape), argument_multiindex)
result = gem.Indexed(result, tensor_indices)
if dual_space:
elem = create_element(dual_space.ufl_element())
if elem.value_shape:
var = gem.Indexed(gem.Variable("b", elem.value_shape),
tensor_indices)
b_arg = [lp.GlobalArg("b", dtype=ScalarType, shape=elem.value_shape)]
value_shape = dual_space.value_shape
if value_shape:
var = gem.Indexed(gem.Variable("b", value_shape), tensor_indices)
b_arg = [lp.GlobalArg("b", dtype=ScalarType, shape=value_shape)]
else:
var = gem.Indexed(gem.Variable("b", (1, )), (0, ))
b_arg = [lp.GlobalArg("b", dtype=ScalarType, shape=(1,))]
Expand Down Expand Up @@ -220,7 +219,7 @@ def prolong_kernel(expression):
assert hierarchy._meshes[int(idx)].cell_set._extruded
V = expression.function_space()
key = (("prolong",)
+ expression.ufl_element().value_shape(meshc)
+ V.value_shape
+ entity_dofs_key(V.finat_element.complex.get_topology())
+ entity_dofs_key(V.finat_element.entity_dofs())
+ entity_dofs_key(coordinates.function_space().finat_element.entity_dofs()))
Expand Down Expand Up @@ -284,7 +283,7 @@ def prolong_kernel(expression):
"evaluate": eval_code,
"spacedim": element.cell.get_spatial_dimension(),
"ncandidate": hierarchy.fine_to_coarse_cells[levelf].shape[1],
"Rdim": numpy.prod(element.value_shape),
"Rdim": numpy.prod(V.value_shape),
"inside_cell": inside_check(element.cell, eps=1e-8, X="Xref"),
"celldist_l1_c_expr": celldist_l1_c_expr(element.cell, X="Xref"),
"Xc_cell_inc": coords_element.space_dimension(),
Expand All @@ -302,7 +301,7 @@ def restrict_kernel(Vf, Vc):
if Vf.extruded:
assert Vc.extruded
key = (("restrict",)
+ Vf.ufl_element().value_shape(Vf.mesh())
+ Vf.value_shape
+ entity_dofs_key(Vf.finat_element.complex.get_topology())
+ entity_dofs_key(Vc.finat_element.complex.get_topology())
+ entity_dofs_key(Vf.finat_element.entity_dofs())
Expand Down Expand Up @@ -390,7 +389,7 @@ def inject_kernel(Vf, Vc):
else:
level_ratio = 1
key = (("inject", level_ratio)
+ Vf.ufl_element().value_shape(Vf.mesh())
+ Vf.value_shape
+ entity_dofs_key(Vc.finat_element.complex.get_topology())
+ entity_dofs_key(Vf.finat_element.complex.get_topology())
+ entity_dofs_key(Vc.finat_element.entity_dofs())
Expand Down Expand Up @@ -465,7 +464,7 @@ def inject_kernel(Vf, Vc):
"celldist_l1_c_expr": celldist_l1_c_expr(Vc.finat_element.cell, X="Xref"),
"tdim": Vc.mesh().topological_dimension(),
"ncandidate": ncandidate,
"Rdim": numpy.prod(Vf_element.value_shape),
"Rdim": numpy.prod(Vf.value_shape),
"Xf_cell_inc": coords_element.space_dimension(),
"f_cell_inc": Vf_element.space_dimension()
}
Expand Down
2 changes: 1 addition & 1 deletion firedrake/mg/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def coarse_cell_to_fine_node_map(Vc, Vf):

def physical_node_locations(V):
element = V.ufl_element()
if element.value_shape(V.mesh()):
if V.value_shape:
assert isinstance(element, (finat.ufl.VectorElement, finat.ufl.TensorElement))
element = element.sub_elements[0]
mesh = V.mesh()
Expand Down
10 changes: 4 additions & 6 deletions firedrake/preconditioners/pmg.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,9 +131,7 @@ def initialize(self, obj):
elements = [ele]
while True:
try:
ele_ = self.coarsen_element(ele)
assert ele_.value_shape(V.mesh()) == ele.value_shape(V.mesh())
ele = ele_
ele = self.coarsen_element(ele)
except ValueError:
break
elements.append(ele)
Expand Down Expand Up @@ -1098,7 +1096,7 @@ def make_mapping_code(Q, cmapping, fmapping, t_in, t_out):
if B:
tensor = ufl.dot(B, tensor) if tensor else B
if tensor is None:
tensor = ufl.Identity(Q.ufl_element().value_shape(Q.mesh())[0])
tensor = ufl.Identity(Q.value_shape[0])

u = ufl.Coefficient(Q)
expr = ufl.dot(tensor, u)
Expand Down Expand Up @@ -1347,8 +1345,8 @@ def make_blas_kernels(self, Vf, Vc):
in_place_mapping = True
except Exception:
qelem = finat.ufl.FiniteElement("DQ", cell=felem.cell, degree=PMGBase.max_degree(felem))
if felem.value_shape(Vf.mesh()):
qelem = finat.ufl.TensorElement(qelem, shape=felem.value_shape(Vf.mesh()), symmetry=felem.symmetry())
if Vf.value_shape:
qelem = finat.ufl.TensorElement(qelem, shape=Vf.value_shape, symmetry=felem.symmetry())
Qf = firedrake.FunctionSpace(Vf.mesh(), qelem)
mapping_output = make_mapping_code(Qf, cmapping, fmapping, "t0", "t1")

Expand Down
4 changes: 2 additions & 2 deletions firedrake/pyplot/pgf.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,12 +217,12 @@ def pgfplot(f, filename, degree=1, complex_component='real', print_latex_example
raise NotImplementedError(f"Not yet implemented for functions in spatial dimension {dim}")
if mesh.extruded:
raise NotImplementedError("Not yet implemented for functions on extruded meshes")
if elem.value_shape(mesh):
if V.value_shape:
raise NotImplementedError("Currently only implemeted for scalar functions")
coordelem = get_embedding_dg_element(mesh.coordinates.function_space().ufl_element(), (dim, )).reconstruct(degree=degree, variant="equispaced")
coordV = FunctionSpace(mesh, coordelem)
coords = Function(coordV).interpolate(SpatialCoordinate(mesh))
elemdg = get_embedding_dg_element(elem, elem.value_shape(mesh)).reconstruct(degree=degree, variant="equispaced")
elemdg = get_embedding_dg_element(elem, V.value_shape).reconstruct(degree=degree, variant="equispaced")
Vdg = FunctionSpace(mesh, elemdg)
fdg = Function(Vdg)
method = get_embedding_method_for_checkpointing(elem)
Expand Down
2 changes: 1 addition & 1 deletion firedrake/slate/static_condensation/hybridization.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def initialize(self, pc):
if len(V) != 2:
raise ValueError("Expecting two function spaces.")

if all(Vi.ufl_element().value_shape(Vi.mesh()) for Vi in V):
if all(Vi.value_shape for Vi in V):
raise ValueError("Expecting an H(div) x L2 pair of spaces.")

# Automagically determine which spaces are vector and scalar
Expand Down
6 changes: 2 additions & 4 deletions firedrake/ufl_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,7 @@ def reconstruct(self, function_space=None,
return self
if not isinstance(number, int):
raise TypeError(f"Expecting an int, not {number}")
mesh = self.function_space().mesh()
if function_space.ufl_element().value_shape(mesh) != self.ufl_element().value_shape(mesh):
if function_space.value_shape != self.function_space().value_shape:
raise ValueError("Cannot reconstruct an Argument with a different value shape.")
return Argument(function_space, number, part=part)

Expand Down Expand Up @@ -141,8 +140,7 @@ def reconstruct(self, function_space=None,
return self
if not isinstance(number, int):
raise TypeError(f"Expecting an int, not {number}")
mesh = self.function_space().mesh()
if function_space.ufl_element().value_shape(mesh) != self.ufl_element().value_shape(mesh):
if function_space.value_shape != self.function_space().value_shape:
raise ValueError("Cannot reconstruct an Coargument with a different value shape.")
return Coargument(function_space, number, part=part)

Expand Down
2 changes: 1 addition & 1 deletion tests/output/test_io_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def _get_mesh(cell_type, comm):
def _get_expr(V):
mesh = V.mesh()
dim = mesh.geometric_dimension()
shape = V.ufl_element().value_shape(mesh)
shape = V.value_shape
if dim == 2:
x, y = SpatialCoordinate(mesh)
z = x * y
Expand Down
2 changes: 1 addition & 1 deletion tests/output/test_io_timestepping.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def _get_expr(V, i):
mesh = V.mesh()
element = V.ufl_element()
x, y = SpatialCoordinate(mesh)
shape = element.value_shape(mesh)
shape = V.value_shape
if element.family() == "Real":
return 7. + i * i
elif shape == ():
Expand Down
6 changes: 3 additions & 3 deletions tests/regression/test_ensembleparallelism.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,13 +205,13 @@ def test_ensemble_reduce(ensemble, mesh, W, urank, urank_sum, root, blocking):
parallel_assert(
lambda: error < 1e-12,
subset=root_ranks,
msg=f"{error = :.5f}"
msg=f"{error=:.5f}"
)
error = errornorm(Function(W).assign(10), u_reduce)
parallel_assert(
lambda: error < 1e-12,
subset={range(COMM_WORLD.size)} - root_ranks,
msg=f"{error = :.5f}"
msg=f"{error=:.5f}"
)

# check that u_reduce dat vector is still synchronised
Expand Down Expand Up @@ -347,7 +347,7 @@ def test_send_and_recv(ensemble, mesh, W, blocking):
parallel_assert(
lambda: error < 1e-12,
subset=root_ranks,
msg=f"{error = :.5f}"
msg=f"{error=:.5f}"
)


Expand Down
2 changes: 1 addition & 1 deletion tests/regression/test_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ def test_assign_with_different_meshes_fails():
def test_assign_vector_const_to_vfs(vcg1):
f = Function(vcg1)

c = Constant(range(1, f.ufl_element().value_shape(vcg1.mesh())[0]+1))
c = Constant(range(1, f.function_space().value_shape[0]+1))

f.assign(c)
assert np.allclose(f.dat.data_ro, c.dat.data_ro)
Expand Down
3 changes: 1 addition & 2 deletions tests/regression/test_fas_snespatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,7 @@ def test_snespatch(mesh, CG1, solver_params):
f = Constant(1, domain=mesh)
F = inner(grad(u), grad(v))*dx - inner(f, v)*dx + inner(u**3 - u, v)*dx

z = zero(CG1.ufl_element().value_shape(mesh))
bcs = DirichletBC(CG1, z, "on_boundary")
bcs = DirichletBC(CG1, 0, "on_boundary")

nvproblem = NonlinearVariationalProblem(F, u, bcs=bcs)
solver = NonlinearVariationalSolver(nvproblem, solver_parameters=solver_params)
Expand Down
4 changes: 2 additions & 2 deletions tests/regression/test_fdm.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def build_riesz_map(V, d):

x = SpatialCoordinate(V.mesh())
x -= Constant([0.5]*len(x))
if V.ufl_element().value_shape(V.mesh()) == ():
if V.value_shape == ():
u_exact = exp(-10*dot(x, x))
u_bc = u_exact
else:
Expand Down Expand Up @@ -195,7 +195,7 @@ def test_variable_coefficient(mesh):
subs = ("on_boundary",)
if mesh.cell_set._extruded:
subs += ("top", "bottom")
bcs = [DirichletBC(V, zero(V.ufl_element().value_shape(mesh)), sub) for sub in subs]
bcs = [DirichletBC(V, 0, sub) for sub in subs]

uh = Function(V)
problem = LinearVariationalProblem(a, L, uh, bcs=bcs)
Expand Down
2 changes: 1 addition & 1 deletion tests/regression/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def test_mismatching_shape_interpolation(V):
VV = VectorFunctionSpace(V.mesh(), 'CG', 1)
f = Function(VV)
with pytest.raises(RuntimeError):
f.interpolate(Constant([1] * (VV.ufl_element().value_shape(VV.mesh())[0] + 1)))
f.interpolate(Constant([1] * (VV.value_shape[0] + 1)))


def test_function_val(V):
Expand Down
4 changes: 2 additions & 2 deletions tests/regression/test_function_spaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ def test_reconstruct_component(space, dg0, rt1, mesh, mesh2, dual):
Z = {"dg0": dg0, "rt1": rt1}[space]
if dual:
Z = Z.dual()
for component in range(Z.value_size):
for component in range(len(Z)):
V1 = Z.sub(component)
V2 = V1.reconstruct(mesh=mesh2)
assert is_dual(V1) == is_dual(V2) == dual
Expand All @@ -293,7 +293,7 @@ def test_reconstruct_sub_component(dg0, rt1, mesh, mesh2, dual):
if dual:
Z = Z.dual()
for index, Vsub in enumerate(Z):
for component in range(Vsub.value_size):
for component in range(len(Vsub._components)):
V1 = Z.sub(index).sub(component)
V2 = V1.reconstruct(mesh=mesh2)
assert is_dual(V1) == is_dual(V2) == dual
Expand Down
2 changes: 1 addition & 1 deletion tests/regression/test_interpolate.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,7 +471,7 @@ def test_interpolation_tensor_convergence():
V = TensorFunctionSpace(mesh, "RT", 1)
x, y = SpatialCoordinate(mesh)

vs = V.ufl_element().value_shape(mesh)
vs = V.value_shape
expr = as_tensor(np.asarray([
sin(2*pi*x*(i+1))*cos(4*pi*y*i)
for i in range(np.prod(vs, dtype=int))
Expand Down
Loading

0 comments on commit a2537dc

Please sign in to comment.