Skip to content
This repository has been archived by the owner on Dec 6, 2024. It is now read-only.

Commit

Permalink
Let FIAT handle general CG/DG variants
Browse files Browse the repository at this point in the history
  • Loading branch information
pbrubeck committed Mar 28, 2024
1 parent 90c20c5 commit 239ed12
Showing 1 changed file with 6 additions and 8 deletions.
14 changes: 6 additions & 8 deletions tsfc/finatinterface.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,7 @@ def convert_finiteelement(element, **kwargs):
kind = 'spectral' # default variant

if element.family() == "Lagrange":
if kind == 'equispaced':
lmbda = finat.Lagrange
elif kind == 'spectral':
if kind == 'spectral':
lmbda = finat.GaussLobattoLegendre
elif kind == 'integral':
lmbda = finat.IntegratedLegendre
Expand All @@ -167,14 +165,13 @@ def convert_finiteelement(element, **kwargs):
deps = {"shift_axes", "restriction"}
return finat.RuntimeTabulated(cell, degree, variant=kind, shift_axes=shift_axes, restriction=restriction), deps
else:
raise ValueError("Variant %r not supported on %s" % (kind, element.cell))
# Let FIAT handle the general case
lmbda = finat.Lagrange
elif element.family() in {"Raviart-Thomas", "Nedelec 1st kind H(curl)",
"Brezzi-Douglas-Marini", "Nedelec 2nd kind H(curl)"}:
lmbda = partial(lmbda, variant=element.variant())
elif element.family() in ["Discontinuous Lagrange", "Discontinuous Lagrange L2"]:
if kind == 'equispaced':
lmbda = finat.DiscontinuousLagrange
elif kind == 'spectral':
if kind == 'spectral':
lmbda = finat.GaussLegendre
elif kind == 'integral':
lmbda = finat.Legendre
Expand All @@ -191,7 +188,8 @@ def convert_finiteelement(element, **kwargs):
deps = {"shift_axes", "restriction"}
return finat.RuntimeTabulated(cell, degree, variant=kind, shift_axes=shift_axes, restriction=restriction, continuous=False), deps
else:
raise ValueError("Variant %r not supported on %s" % (kind, element.cell))
# Let FIAT handle the general case
lmbda = finat.DiscontinuousLagrange
elif element.family() == ["DPC", "DPC L2"]:
if element.cell.geometric_dimension() == 2:
element = element.reconstruct(cell=ufl.cell.hypercube(2))
Expand Down

0 comments on commit 239ed12

Please sign in to comment.