diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index 05448fb7b2..1a26285904 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -988,18 +988,16 @@ def callable(): else: # Make sure we have an expression of the right length i.e. a value for # each component in the value shape of each function space - dims = [numpy.prod(fs.value_shape, dtype=int) - for fs in V] loops = [] - if numpy.prod(expr.ufl_shape, dtype=int) != sum(dims): + if numpy.prod(expr.ufl_shape, dtype=int) != V.value_size: raise RuntimeError('Expression of length %d required, got length %d' - % (sum(dims), numpy.prod(expr.ufl_shape, dtype=int))) + % (V.value_size, numpy.prod(expr.ufl_shape, dtype=int))) if len(V) > 1: raise NotImplementedError( "UFL expressions for mixed functions are not yet supported.") loops.extend(_interpolator(V, tensor, expr, subset, arguments, access, bcs=bcs)) if bcs and len(arguments) == 0: - loops.extend([partial(bc.apply, f) for bc in bcs]) + loops.extend(partial(bc.apply, f) for bc in bcs) def callable(loops, f): for l in loops: diff --git a/firedrake/mg/embedded.py b/firedrake/mg/embedded.py index d6a3b72a94..984ecff8a0 100644 --- a/firedrake/mg/embedded.py +++ b/firedrake/mg/embedded.py @@ -40,8 +40,8 @@ class Cache(object): """A caching object for work vectors and matrices. :arg element: The element to use for the caching.""" - def __init__(self, key): - self.embedding_element = get_embedding_dg_element(*key) + def __init__(self, ufl_element, value_shape): + self.embedding_element = get_embedding_dg_element(ufl_element, value_shape) self._dat_versions = {} self._V_DG_mass = {} self._DG_inv_mass = {} @@ -83,16 +83,12 @@ def _native_transfer(self, element, op): return self.native_transfers.setdefault(element, ops)[op] return None - def cache(self, key): + def cache(self, V): + key = (V.ufl_element(), V.value_shape) try: return self.caches[key] except KeyError: - return self.caches.setdefault(key, TransferManager.Cache(key)) - - def get_cache_key(self, V): - elem = V.ufl_element() - value_shape = V.value_shape - return elem, value_shape + return self.caches.setdefault(key, TransferManager.Cache(*key)) def V_dof_weights(self, V): """Dof weights for averaging projection. @@ -100,7 +96,7 @@ def V_dof_weights(self, V): :arg V: function space to compute weights for. :returns: A PETSc Vec. """ - cache = self.cache(self.get_cache_key(V)) + cache = self.cache(V) key = V.dim() try: return cache._V_dof_weights[key] @@ -125,7 +121,7 @@ def V_DG_mass(self, V, DG): :arg DG: the DG space :returns: A PETSc Mat mapping from V -> DG """ - cache = self.cache(self.get_cache_key(V)) + cache = self.cache(V) key = V.dim() try: return cache._V_DG_mass[key] @@ -140,7 +136,7 @@ def DG_inv_mass(self, DG): :arg DG: the DG space :returns: A PETSc Mat. """ - cache = self.cache(self.get_cache_key(DG)) + cache = self.cache(DG) key = DG.dim() try: return cache._DG_inv_mass[key] @@ -156,7 +152,7 @@ def V_approx_inv_mass(self, V, DG): :arg DG: the DG space :returns: A PETSc Mat mapping from V -> DG. """ - cache = self.cache(self.get_cache_key(V)) + cache = self.cache(DG) key = V.dim() try: return cache._V_approx_inv_mass[key] @@ -174,7 +170,7 @@ def V_inv_mass_ksp(self, V): :arg V: a function space. :returns: A PETSc KSP for inverting (V, V). """ - cache = self.cache(self.get_cache_key(V)) + cache = self.cache(V) key = V.dim() try: return cache._V_inv_mass_ksp[key] @@ -196,7 +192,7 @@ def DG_work(self, V): :returns: A Function in the embedding DG space. """ needs_dual = ufl.duals.is_dual(V) - cache = self.cache(self.get_cache_key(V)) + cache = self.cache(V) key = (V.dim(), needs_dual) try: return cache._DG_work[key] @@ -213,7 +209,7 @@ def work_vec(self, V): :arg V: a function space. :returns: A PETSc Vec for V. """ - cache = self.cache(self.get_cache_key(V)) + cache = self.cache(V) key = V.dim() try: return cache._work_vec[key] @@ -226,7 +222,7 @@ def requires_transfer(self, V, transfer_op, source, target): key = (transfer_op, weakref.ref(source.dat), weakref.ref(target.dat)) dat_versions = (source.dat.dat_version, target.dat.dat_version) try: - return self.cache(self.get_cache_key(V))._dat_versions[key] != dat_versions + return self.cache(V)._dat_versions[key] != dat_versions except KeyError: return True @@ -234,7 +230,7 @@ def cache_dat_versions(self, V, transfer_op, source, target): """Record the returned dat_versions of the source and target.""" key = (transfer_op, weakref.ref(source.dat), weakref.ref(target.dat)) dat_versions = (source.dat.dat_version, target.dat.dat_version) - self.cache(self.get_cache_key(V))._dat_versions[key] = dat_versions + self.cache(V)._dat_versions[key] = dat_versions @PETSc.Log.EventDecorator() def op(self, source, target, transfer_op): diff --git a/firedrake/mg/kernels.py b/firedrake/mg/kernels.py index fc32f0c869..5405b6d726 100644 --- a/firedrake/mg/kernels.py +++ b/firedrake/mg/kernels.py @@ -283,7 +283,7 @@ def prolong_kernel(expression): "evaluate": eval_code, "spacedim": element.cell.get_spatial_dimension(), "ncandidate": hierarchy.fine_to_coarse_cells[levelf].shape[1], - "Rdim": numpy.prod(V.value_shape), + "Rdim": V.value_size, "inside_cell": inside_check(element.cell, eps=1e-8, X="Xref"), "celldist_l1_c_expr": celldist_l1_c_expr(element.cell, X="Xref"), "Xc_cell_inc": coords_element.space_dimension(), diff --git a/firedrake/mg/utils.py b/firedrake/mg/utils.py index 02f0700b2d..80e5e820f6 100644 --- a/firedrake/mg/utils.py +++ b/firedrake/mg/utils.py @@ -146,7 +146,7 @@ def physical_node_locations(V): try: return cache[key] except KeyError: - Vc = firedrake.FunctionSpace(mesh, finat.ufl.VectorElement(element)) + Vc = firedrake.VectorFunctionSpace(mesh, element) # FIXME: This is unsafe for DG coordinates and CG target spaces. locations = firedrake.assemble(firedrake.Interpolate(firedrake.SpatialCoordinate(mesh), Vc)) return cache.setdefault(key, locations) diff --git a/firedrake/preconditioners/patch.py b/firedrake/preconditioners/patch.py index f7a14c1e2f..0a7bad5575 100644 --- a/firedrake/preconditioners/patch.py +++ b/firedrake/preconditioners/patch.py @@ -585,18 +585,18 @@ def bcdofs(bc, ghost=True): if ghost: offset += sum(Z.sub(j).dof_count for j in range(idx)) else: - offset += sum(Z.sub(j).dof_dset.size * Z.sub(j).value_size for j in range(idx)) + offset += sum(Z.sub(j).dof_dset.size * Z.sub(j).block_size for j in range(idx)) else: raise NotImplementedError("How are you taking a .sub?") Z = Z.sub(idx) if Z.parent is not None and isinstance(Z.parent.ufl_element(), VectorElement): - bs = Z.parent.value_size + bs = Z.parent.block_size start = 0 stop = 1 else: - bs = Z.value_size + bs = Z.block_size start = 0 stop = bs nodes = bc.nodes @@ -868,7 +868,7 @@ def initialize(self, obj): offsets = numpy.append([0], numpy.cumsum([W.dof_count for W in V])).astype(PETSc.IntType) patch.setPatchDiscretisationInfo([W.dm for W in V], - numpy.array([W.value_size for + numpy.array([W.block_size for W in V], dtype=PETSc.IntType), [W.cell_node_list for W in V], offsets, diff --git a/firedrake/preconditioners/pmg.py b/firedrake/preconditioners/pmg.py index ab6de8c6b7..251f585cbe 100644 --- a/firedrake/preconditioners/pmg.py +++ b/firedrake/preconditioners/pmg.py @@ -1375,8 +1375,8 @@ def make_blas_kernels(self, Vf, Vc): # We could benefit from loop tiling for the transpose, but that makes the code # more complicated. - fshape = (numpy.prod(Vf.shape), Vf.finat_element.space_dimension()) - cshape = (numpy.prod(Vc.shape), Vc.finat_element.space_dimension()) + fshape = (Vf.block_size, Vf.finat_element.space_dimension()) + cshape = (Vc.block_size, Vc.finat_element.space_dimension()) lwork = numpy.prod([max(*dims) for dims in zip(*shapes)]) lwork = max(lwork, max(numpy.prod(fshape), numpy.prod(cshape))) diff --git a/tests/multigrid/test_non_nested.py b/tests/multigrid/test_non_nested.py index 8a68b18b7a..5dfb63ff8b 100644 --- a/tests/multigrid/test_non_nested.py +++ b/tests/multigrid/test_non_nested.py @@ -59,7 +59,7 @@ def test_sphere_mg(): u = TrialFunction(V) v = TestFunction(V) - a = (inner(u, v) + inner(u, v))*dx + a = (inner(grad(u), grad(v)) + inner(u, v))*dx f1 = exp((x+y+z)/R)*x*y*z/R**3 F = inner(f1, v)*dx @@ -91,4 +91,4 @@ def test_sphere_mg(): prob = LinearVariationalProblem(a, F, w) solver = LinearVariationalSolver(prob, solver_parameters=mg_params) solver.solve() - assert solver.snes.ksp.getIterationNumber() < 5 + assert solver.snes.ksp.getIterationNumber() < 7