diff --git a/tsfc/finatinterface.py b/tsfc/finatinterface.py index f9f461fc..2e524735 100644 --- a/tsfc/finatinterface.py +++ b/tsfc/finatinterface.py @@ -20,7 +20,7 @@ # along with FFC. If not, see . import weakref -from functools import partial, singledispatch +from functools import singledispatch import FIAT import finat @@ -124,6 +124,23 @@ def convert(element, **kwargs): raise ValueError("Unsupported element type %s" % type(element)) +cg_interval_variants = { + "fdm": finat.FDMLagrange, + "fdm_ipdg": finat.FDMLagrange, + "fdm_quadrature": finat.FDMQuadrature, + "fdm_broken": finat.FDMBrokenH1, + "fdm_hermite": finat.FDMHermite, +} + + +dg_interval_variants = { + "fdm": finat.FDMDiscontinuousLagrange, + "fdm_quadrature": finat.FDMDiscontinuousLagrange, + "fdm_ipdg": lambda *args: finat.DiscontinuousElement(finat.FDMLagrange(*args)), + "fdm_broken": finat.FDMBrokenL2, +} + + # Base finite elements first @convert.register(finat.ufl.FiniteElement) def convert_finiteelement(element, **kwargs): @@ -152,30 +169,19 @@ def convert_finiteelement(element, **kwargs): finat_elem, deps = _create_element(element, **kwargs) return finat.FlattenedDimensions(finat_elem), deps + kw = {} kind = element.variant() if kind is None: kind = 'spectral' # default variant - is_interval = element.cell.cellname() == 'interval' - if element.family() in {"Raviart-Thomas", "Nedelec 1st kind H(curl)", - "Brezzi-Douglas-Marini", "Nedelec 2nd kind H(curl)", - "Argyris"}: - lmbda = partial(lmbda, variant=element.variant()) - elif element.family() == "Lagrange": + if element.family() == "Lagrange": if kind == 'spectral': lmbda = finat.GaussLobattoLegendre - elif kind.startswith('integral'): - lmbda = partial(finat.IntegratedLegendre, variant=kind) - elif kind in ['fdm', 'fdm_ipdg'] and is_interval: - lmbda = finat.FDMLagrange - elif kind == 'fdm_quadrature' and is_interval: - lmbda = finat.FDMQuadrature - elif kind == 'fdm_broken' and is_interval: - lmbda = finat.FDMBrokenH1 - elif kind == 'fdm_hermite' and is_interval: - lmbda = finat.FDMHermite - elif kind in ['demkowicz', 'fdm']: - lmbda = partial(finat.IntegratedLegendre, variant=kind) + elif element.cell.cellname() == "interval" and kind in cg_interval_variants: + lmbda = cg_interval_variants[kind] + elif kind.startswith('integral') or kind in ['demkowicz', 'fdm']: + lmbda = finat.IntegratedLegendre + kw["variant"] = kind elif kind in ['mgd', 'feec', 'qb', 'mse']: degree = element.degree() shift_axes = kwargs["shift_axes"] @@ -184,20 +190,17 @@ def convert_finiteelement(element, **kwargs): return finat.RuntimeTabulated(cell, degree, variant=kind, shift_axes=shift_axes, restriction=restriction), deps else: # Let FIAT handle the general case - lmbda = partial(finat.Lagrange, variant=kind) + lmbda = finat.Lagrange + kw["variant"] = kind + elif element.family() in ["Discontinuous Lagrange", "Discontinuous Lagrange L2"]: if kind == 'spectral': lmbda = finat.GaussLegendre - elif kind.startswith('integral'): - lmbda = partial(finat.Legendre, variant=kind) - elif kind in ['fdm', 'fdm_quadrature'] and is_interval: - lmbda = finat.FDMDiscontinuousLagrange - elif kind == 'fdm_ipdg' and is_interval: - lmbda = lambda *args: finat.DiscontinuousElement(finat.FDMLagrange(*args)) - elif kind in 'fdm_broken' and is_interval: - lmbda = finat.FDMBrokenL2 - elif kind in ['demkowicz', 'fdm']: - lmbda = partial(finat.Legendre, variant=kind) + elif element.cell.cellname() == "interval" and kind in dg_interval_variants: + lmbda = dg_interval_variants[kind] + elif kind.startswith('integral') or kind in ['demkowicz', 'fdm']: + lmbda = finat.Legendre + kw["variant"] = kind elif kind in ['mgd', 'feec', 'qb', 'mse']: degree = element.degree() shift_axes = kwargs["shift_axes"] @@ -206,13 +209,13 @@ def convert_finiteelement(element, **kwargs): return finat.RuntimeTabulated(cell, degree, variant=kind, shift_axes=shift_axes, restriction=restriction, continuous=False), deps else: # Let FIAT handle the general case - lmbda = partial(finat.DiscontinuousLagrange, variant=kind) - elif element.family() == ["DPC", "DPC L2", "S"]: - dim = element.cell.geometric_dimension() - if dim > 1: - element = element.reconstruct(cell=ufl.cell.hypercube(dim)) + lmbda = finat.DiscontinuousLagrange + kw["variant"] = kind + + elif element.variant() is not None: + kw["variant"] = element.variant() - return lmbda(cell, element.degree()), set() + return lmbda(cell, element.degree(), **kw), set() # Element modifiers and compound element types