Commit

only do most costly part on checkerboarded lattice
lehner committed Jun 15, 2024
1 parent 59c66f4 commit a686600
Showing 1 changed file with 28 additions and 12 deletions.
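The change restricts the costly per-generator trace loop in compute_adj_abc to one parity of the lattice: the operands are projected onto a single checkerboard with g.pick_checkerboard, the loop runs on the half-sized fields, and the result is written back into a zeroed full-lattice field with g.set_checkerboard. The parity is taken from the smearing's "checkerboard" parameter; call sites whose operands were shifted by a single site pass cb.inv() instead, since a unit shift moves data onto the opposite parity. A minimal sketch of that pick/compute/set idiom, not part of the commit (the grid size, the g.complex field, the random fill, and the g.odd parity are illustrative assumptions):

import gpt as g

# assumed grid and a scalar complex field, for illustration only
grid = g.grid([8, 8, 8, 16], g.double)
src = g.complex(grid)
g.random("sketch").cnormal(src)

# project onto one parity and do the expensive work on the half lattice
half = g.pick_checkerboard(g.odd, src)
half = g(2.0 * half)  # stand-in for the costly per-site computation

# scatter the half-lattice result back into a zeroed full-lattice field
dst = g.complex(grid)
dst[:] = 0
g.set_checkerboard(dst, half)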
40 changes: 28 additions & 12 deletions lib/gpt/qcd/gauge/smear/local_stout.py
@@ -93,8 +93,15 @@ def compute_adj_ab(A, B, C, generators, cache):
ein(C, g(g.adj(A)), g(B), *fgenerators)


def compute_adj_abc(A, B, C, V, generators, cache):
def compute_adj_abc(_A, _B, _C, _V, generators, cache, parity):

t = g.timer("compute_adj_abc")
t("checkerboarding")
A = g.pick_checkerboard(parity, _A)
B = g.pick_checkerboard(parity, _B)
C = g.pick_checkerboard(parity, _C)
V = g.pick_checkerboard(parity, _V)

t("other")
ng = len(generators)
tmp2 = {}
@@ -110,6 +117,11 @@ def compute_adj_abc(A, B, C, V, generators, cache):
tmp2[a,] = g(g.trace(C * D))
t("merge")
g.merge_color(V, tmp2)
t("checkerboarding")

_V[:] = 0
g.set_checkerboard(_V, V)

t()
# g.message(t)

@@ -372,6 +384,8 @@ def __call__(self, U):
def gradient(self, U, dU):
assert dU == U

cb = self.stout.params["checkerboard"]

cache_ab = {}
J_ac, NxxAd, Z_ac, M, fm, M_ab = self.stout.jacobian_components(U, cache_ab)

Expand Down Expand Up @@ -418,7 +432,8 @@ def gradient(self, U, dU):
PlaqR = g(M * fm)
FdetV = g.lattice(grid, adjoint_vector_otype)
cache = {}
compute_adj_abc(PlaqL, PlaqR, MpInvJx, FdetV, generators, cache)

compute_adj_abc(PlaqL, PlaqR, MpInvJx, FdetV, generators, cache, cb)

Fdet2_mu = g.copy(FdetV)
Fdet1_mu = g(0 * FdetV)
@@ -455,13 +470,14 @@ def gradient(self, U, dU):

PlaqR = g((-1.0) * PlaqR)
t("compute_adj_abc")
compute_adj_abc(PlaqL, PlaqR, MpInvJx, FdetV, generators, cache)

compute_adj_abc(PlaqL, PlaqR, MpInvJx, FdetV, generators, cache, cb)
t("non-local")
Fdet2_nu = g.copy(FdetV)

# + nw acw
PlaqR = g(rho * csf(U[nu], nu, csb(U[mu], mu, csb(U[nu], nu))))
PlaqL = csb(U_mu_masked, mu)
PlaqL = g(csb(U_mu_masked, mu))

dJdXe_nMpInv_y = g.cshift(dJdXe_nMpInv, mu, -1)
t("compute_adj_ab")
@@ -471,13 +487,13 @@ def gradient(self, U, dU):

MpInvJx_nu = g.cshift(MpInvJx, mu, -1)
t("compute_adj_abc")
compute_adj_abc(PlaqL, PlaqR, MpInvJx_nu, FdetV, generators, cache)
compute_adj_abc(PlaqL, PlaqR, MpInvJx_nu, FdetV, generators, cache, cb.inv())
t("non-local")
Fdet2_nu += FdetV

# - nu cw
PlaqL = g(rho * csf(U[mu], mu, csf(U[nu], nu, csb(U_mu_masked, mu))))
PlaqR = csf(U[nu], nu)
PlaqR = g(csf(U[nu], nu))

dJdXe_nMpInv_y = g.cshift(dJdXe_nMpInv, nu, 1)
t("compute_adj_ab")
@@ -487,7 +503,7 @@ def gradient(self, U, dU):

MpInvJx_nu = g.cshift(MpInvJx, nu, 1)
t("compute_adj_abc")
compute_adj_abc(PlaqL, PlaqR, MpInvJx_nu, FdetV, generators, cache)
compute_adj_abc(PlaqL, PlaqR, MpInvJx_nu, FdetV, generators, cache, cb.inv())
t("non-local")
Fdet2_nu += FdetV

@@ -506,7 +522,7 @@ def gradient(self, U, dU):
MpInvJx_nu = g.cshift(MpInvJx, mu, -1)
MpInvJx_nu = g.cshift(MpInvJx_nu, nu, 1)
t("compute_adj_abc")
compute_adj_abc(PlaqL, PlaqR, MpInvJx_nu, FdetV, generators, cache)
compute_adj_abc(PlaqL, PlaqR, MpInvJx_nu, FdetV, generators, cache, cb)
t("non-local")
Fdet2_nu += FdetV

@@ -518,7 +534,7 @@ def gradient(self, U, dU):

# mu cw
PlaqL = g((-rho) * csf(U[mu], mu, csb(U[nu], nu, csb(U_mu_masked, mu))))
PlaqR = csb(U[nu], nu)
PlaqR = g(csb(U[nu], nu))

dJdXe_nMpInv_y = g.cshift(dJdXe_nMpInv, nu, -1)
t("compute_adj_ab")
@@ -528,13 +544,13 @@ def gradient(self, U, dU):

MpInvJx_nu = g.cshift(MpInvJx, nu, -1)
t("compute_adj_abc")
compute_adj_abc(PlaqL, PlaqR, MpInvJx_nu, FdetV, generators, cache)
compute_adj_abc(PlaqL, PlaqR, MpInvJx_nu, FdetV, generators, cache, cb.inv())
t("non-local")
Fdet2_mu += FdetV

# mu acw
PlaqL = g((-rho) * csf(U[mu], mu, csf(U[nu], nu, csb(U_mu_masked, mu))))
PlaqR = csf(U[nu], nu)
PlaqR = g(csf(U[nu], nu))

dJdXe_nMpInv_y = g.cshift(dJdXe_nMpInv, nu, 1)

@@ -546,7 +562,7 @@ def gradient(self, U, dU):
MpInvJx_nu = g.cshift(MpInvJx, nu, 1)

t("compute_adj_abc")
compute_adj_abc(PlaqL, PlaqR, MpInvJx_nu, FdetV, generators, cache)
compute_adj_abc(PlaqL, PlaqR, MpInvJx_nu, FdetV, generators, cache, cb.inv())
t("non-local")
Fdet2_mu += FdetV

