Skip to content

Commit

Permalink
use Kronecker.KroneckerProduct more sparingly
Browse files Browse the repository at this point in the history
  • Loading branch information
jlchan committed Jul 26, 2024
1 parent f39907c commit 5ae6977
Showing 1 changed file with 15 additions and 12 deletions.
27 changes: 15 additions & 12 deletions src/RefElemData_polynomial.jl
Original file line number Diff line number Diff line change
Expand Up @@ -345,11 +345,11 @@ function RefElemData(elem::Quad,
M1D = Vq1D' * diagm(wq1D) * Vq1D

# form kronecker products of multidimensional matrices to invert/multiply
VDM = kronecker(VDM_1D, VDM_1D)
invVDM = kronecker(invVDM_1D, invVDM_1D)
invM = kronecker(invM_1D, invM_1D)
VDM = kron(VDM_1D, VDM_1D)
invVDM = kron(invVDM_1D, invVDM_1D)
invM = kron(invM_1D, invM_1D)

M = kronecker(M1D, M1D)
M = kron(M1D, M1D)

_, Vr, Vs = basis(elem, N, r, s)
Dr, Ds = (A -> A * invVDM).((Vr, Vs))
Expand All @@ -363,7 +363,7 @@ function RefElemData(elem::Quad,

# quadrature nodes - build from 1D nodes.
rq, sq, wq = tensor_product_quadrature(elem, approximation_type.data.quad_rule_1D...)
Vq = kronecker(Vq1D, Vq1D) # vandermonde(elem, N, rq, sq, tq) * invVDM
Vq = kron(Vq1D, Vq1D) # vandermonde(elem, N, rq, sq, tq) * invVDM
Pq = invM * (Vq' * diagm(wq))

Vf = vandermonde(elem, N, rf, sf) * invVDM
Expand All @@ -372,7 +372,7 @@ function RefElemData(elem::Quad,
# plotting nodes
rp1D = LinRange(-1, 1, Nplot + 1)
Vp1D = vandermonde(Line(), N, rp1D) / VDM_1D
Vp = kronecker(Vp1D, Vp1D)
Vp = kron(Vp1D, Vp1D)
rp, sp = vec.(StartUpDG.NodesAndModes.meshgrid(rp1D, rp1D))

return RefElemData(elem, approximation_type, N, fv, V1,
Expand Down Expand Up @@ -406,11 +406,14 @@ function RefElemData(elem::Hex,
M1D = Vq1D' * diagm(wq1D) * Vq1D

# form kronecker products of multidimensional matrices to invert/multiply
VDM = kronecker(VDM_1D, VDM_1D, VDM_1D)
invVDM = kronecker(invVDM_1D, invVDM_1D, invVDM_1D)
invM = kronecker(invM_1D, invM_1D, invM_1D)
# use dense matrix "kron" if N is 4 or lower; use memory-saving "kronecker" otherwise
build_kronecker_product = (N < 5) ? kron : kronecker

VDM = build_kronecker_product(VDM_1D, VDM_1D, VDM_1D)
invVDM = build_kronecker_product(invVDM_1D, invVDM_1D, invVDM_1D)
invM = build_kronecker_product(invM_1D, invM_1D, invM_1D)

M = kronecker(M1D, M1D, M1D)
M = build_kronecker_product(M1D, M1D, M1D)

_, Vr, Vs, Vt = basis(elem, N, r, s, t)
Dr, Ds, Dt = (A -> A * invVDM).((Vr, Vs, Vt))
Expand All @@ -424,7 +427,7 @@ function RefElemData(elem::Hex,

# quadrature nodes - build from 1D nodes.
rq, sq, tq, wq = tensor_product_quadrature(elem, approximation_type.data.quad_rule_1D...)
Vq = kronecker(Vq1D, Vq1D, Vq1D) # vandermonde(elem, N, rq, sq, tq) * invVDM
Vq = build_kronecker_product(Vq1D, Vq1D, Vq1D) # vandermonde(elem, N, rq, sq, tq) * invVDM
Pq = invM * (Vq' * diagm(wq))

Vf = vandermonde(elem, N, rf, sf, tf) * invVDM
Expand All @@ -433,7 +436,7 @@ function RefElemData(elem::Hex,
# plotting nodes
rp1D = LinRange(-1, 1, Nplot + 1)
Vp1D = vandermonde(Line(), N, rp1D) / VDM_1D
Vp = kronecker(Vp1D, Vp1D, Vp1D)
Vp = build_kronecker_product(Vp1D, Vp1D, Vp1D)
rp, sp, tp = vec.(StartUpDG.NodesAndModes.meshgrid(rp1D, rp1D, rp1D))

return RefElemData(elem, approximation_type, N, fv, V1,
Expand Down

0 comments on commit 5ae6977

Please sign in to comment.